"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
import json

import pytest
import requests_mock
from core.redis import redis_healthcheck
from ml.models import MLBackend
from projects.models import Project
from tasks.models import Annotation, AnnotationDraft, Prediction, Task
from users.models import User

from .utils import make_project

# Project settings shared by the prediction tests: two read-only text fields
# (`meta_info`, `text`) and a single-choice control over `text` offering
# `class_A` / `class_B`.
_project_for_text_choices_onto_A_B_classes = {
    'title': 'Test',
    'label_config': """
        <View>
          <Text name="meta_info" value="$meta_info"></Text>
          <Text name="text" value="$text"></Text>
          <Choices name="text_class" toName="text" choice="single">
            <Choice value="class_A"></Choice>
            <Choice value="class_B"></Choice>
          </Choices>
        </View>""",
}

# Task payloads ("A" and "B") providing the `meta_info` and `text` fields
# referenced by the label config above.
_2_tasks_with_textA_and_textB = [
    dict(meta_info='meta info A', text='text A'),
    dict(meta_info='meta info B', text='text B'),
]

# Mocked model outputs, one per task: each entry holds a single labelled region
# on `text` plus the score attached to that prediction.
_2_prediction_results_for_textA_textB = [
    {
        'result': [
            {
                'from_name': 'text_class',
                'to_name': 'text',
                'type': 'labels',
                'value': {'labels': [label], 'start': 0, 'end': 1},
            }
        ],
        'score': score,
    }
    for label, score in (('class_A', 0.95), ('class_B', 0.59))
]


def run_task_predictions(client, project, mocker):
    """Trigger predictions on the project's ML backend through the REST API.

    Looks up the ML backend registered for ``project`` at
    ``http://localhost:8999`` and POSTs to its ``/api/ml/<id>/predict`` endpoint.

    :param client: test client used to issue the POST request
    :param project: project whose ML backend should run predictions
    :param mocker: unused; kept so existing callers keep working
    :return: the response returned by ``client.post``
    """
    ml_backend = MLBackend.objects.filter(project=project.id, url='http://localhost:8999').first()
    # Fail with a readable message instead of an AttributeError on `None.id`.
    assert ml_backend is not None, 'No ML backend at http://localhost:8999 is connected to the project'
    return client.post(f'/api/ml/{ml_backend.id}/predict')


@pytest.mark.skipif(not redis_healthcheck(), reason='Starting predictions requires Redis server enabled')
@pytest.mark.parametrize(
    'project_config, tasks, annotations, prediction_results, log_messages, model_version_in_request, use_ground_truth',
    [
        (
            # project config
            _project_for_text_choices_onto_A_B_classes,
            # tasks
            _2_tasks_with_textA_and_textB,
            # annotations
            [
                dict(
                    result=[
                        {
                            'from_name': 'text_class',
                            'to_name': 'text',
                            'type': 'labels',
                            'value': {'labels': ['class_A'], 'start': 0, 'end': 1},
                        }
                    ],
                    ground_truth=True,
                ),
                dict(
                    result=[
                        {
                            'from_name': 'text_class',
                            'to_name': 'text',
                            'type': 'labels',
                            'value': {'labels': ['class_B'], 'start': 0, 'end': 1},
                        }
                    ],
                    ground_truth=True,
                ),
            ],
            # prediction results
            _2_prediction_results_for_textA_textB,
            # log messages
            None,
            # model version in request
            '12345',
            # use ground truth: annotations above are NOT stored
            False,
        ),
        (
            # project config
            _project_for_text_choices_onto_A_B_classes,
            # tasks
            _2_tasks_with_textA_and_textB,
            # annotations
            [
                dict(
                    result=[
                        {
                            'from_name': 'text_class',
                            'to_name': 'text',
                            'type': 'labels',
                            'value': {'labels': ['class_A'], 'start': 0, 'end': 1},
                        }
                    ],
                    ground_truth=True,
                ),
                dict(
                    result=[
                        {
                            'from_name': 'text_class',
                            'to_name': 'text',
                            'type': 'labels',
                            'value': {'labels': ['class_B'], 'start': 0, 'end': 1},
                        }
                    ],
                    ground_truth=True,
                ),
            ],
            # prediction results
            _2_prediction_results_for_textA_textB,
            # log messages
            None,
            # model version in request
            '12345',
            # use ground truth: annotations above are stored for each task
            True,
        ),
    ],
)
@pytest.mark.django_db
def test_predictions(
    business_client,
    project_config,
    tasks,
    annotations,
    prediction_results,
    log_messages,
    model_version_in_request,
    use_ground_truth,
    mocker,
):
    """Run predictions for a freshly created project and verify what gets stored.

    The ML backend endpoints are mocked: ``/setup`` reports
    ``model_version_in_request`` and ``/predict`` always answers with the first
    entry of ``prediction_results``. After triggering predictions the test
    expects as many stored predictions as there are tasks, each one matching
    that first entry and carrying the backend's model version.

    ``log_messages`` is accepted from the parametrization but not used here.
    """
    # create project with predefined task set
    # (presumably make_project also connects the ML backend at localhost:8999 - confirm in .utils)
    project = make_project(project_config, business_client.user)

    for task, annotation in zip(tasks, annotations):
        t = Task.objects.create(data=task, project=project)
        if use_ground_truth:
            Annotation.objects.create(task=t, **annotation)

    # Every call to the mocked /predict endpoint returns this single result,
    # so it is the expected content of all stored predictions.
    expected_prediction = prediction_results[0]

    # run prediction
    with requests_mock.Mocker() as m:
        m.post('http://localhost:8999/setup', text=json.dumps({'model_version': model_version_in_request}))
        m.post(
            'http://localhost:8999/predict',
            text=json.dumps({'results': [expected_prediction], 'model_version': model_version_in_request}),
        )
        r = run_task_predictions(business_client, project, mocker)
        assert r.status_code == 200
        assert m.called

    # check the predictions that were stored
    predictions = Prediction.objects.all()
    ml_backend = MLBackend.objects.get(url='http://localhost:8999')

    assert predictions.count() == len(tasks)

    for actual_prediction in predictions:
        assert actual_prediction.result == expected_prediction['result']
        assert actual_prediction.score == expected_prediction['score']
        assert ml_backend.model_version == actual_prediction.model_version


@pytest.mark.skipif(not redis_healthcheck(), reason='Starting predictions requires Redis server enabled')
@pytest.mark.parametrize(
    'test_name, project_config, setup_returns_model_version, tasks, annotations, '
    'input_predictions, prediction_call_count, num_project_stats, num_ground_truth_in_stats, '
    'num_ground_truth_fit_predictions',
    [
        (
            # test name just for reference
            'All predictions are outdated, project.model_version is outdated too',
            # project config: contains old model version
            dict(
                title='Test',
                model_version='12345_old',
                label_config="""
                <View>
                  <Text name="txt" value="$text"></Text>
                  <Choices name="cls" toName="txt" choice="single">
                    <Choice value="class_A"></Choice>
                    <Choice value="class_B"></Choice>
                  </Choices>
                </View>""",
            ),
            # setup API returns this model version
            '12345',
            # task data
            [{'text': 'text A'}, {'text': 'text B'}],
            # annotations: there are no annotations
            [None, None],
            # predictions: 2 predictions are from old model version
            [
                {
                    'result': [
                        {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}}
                    ],
                    'score': 0.95,
                    'model_version': '12345_old',
                },
                {
                    'result': [
                        {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_B']}}
                    ],
                    'score': 0.59,
                    'model_version': '12345_old',
                },
            ],
            # prediction call count is 2 for both tasks with old predictions
            2,
            # ground_truth stats (not checked by the test body)
            0,
            0,
            0,
        ),
        (
            # test name just for reference
            'All predictions are up-to-date',
            # project config: contains old model version
            dict(
                title='Test',
                model_version='12345_old',
                label_config="""
        <View>
          <Text name="txt" value="$text"></Text>
          <Choices name="cls" toName="txt" choice="single">
            <Choice value="class_A"></Choice>
            <Choice value="class_B"></Choice>
          </Choices>
        </View>""",
            ),
            # setup API returns this model version
            '12345',
            # task data
            [{'text': 'text A'}, {'text': 'text B'}],
            # annotations: there are no annotations
            [None, None],
            # predictions: both predictions are from the version returned by setup
            [
                {
                    'result': [
                        {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}}
                    ],
                    'score': 0.95,
                    'model_version': '12345',
                },
                {
                    'result': [
                        {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_B']}}
                    ],
                    'score': 0.59,
                    'model_version': '12345',
                },
            ],
            # prediction call count is 0 since predictions are up to date
            0,
            # ground_truth stats (not checked by the test body)
            0,
            0,
            0,
        ),
        (
            # test name just for reference
            'Some predictions are outdated, other are up-to-date. project.model_version is up-to-date',
            # project config: contains actual model version
            dict(
                title='Test',
                model_version='12345',
                label_config="""
        <View>
          <Text name="txt" value="$text"></Text>
          <Choices name="cls" toName="txt" choice="single">
            <Choice value="class_A"></Choice>
            <Choice value="class_B"></Choice>
          </Choices>
        </View>""",
            ),
            # setup API returns this model version
            '12345',
            # task data
            [{'text': 'text A'}, {'text': 'text B'}],
            # annotations: there are no annotations
            [None, None],
            # predictions: 2 predictions, one from the new model version, second from old
            [
                {
                    'result': [
                        {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}}
                    ],
                    'score': 0.95,
                    'model_version': '12345',
                },
                {
                    'result': [
                        {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_B']}}
                    ],
                    'score': 0.59,
                    'model_version': '12345_old',
                },
            ],
            # prediction call count is 1 only for the task with old predictions
            1,
            # ground_truth stats (not checked by the test body)
            0,
            0,
            0,
        ),
        (
            # test name just for reference
            'Some predictions are outdated, other are up-to-date. project.model_version is outdated',
            # project config: contains old model version
            dict(
                title='Test',
                model_version='12345_old',
                label_config="""
<View>
  <Text name="txt" value="$text"></Text>
  <Choices name="cls" toName="txt" choice="single">
    <Choice value="class_A"></Choice>
    <Choice value="class_B"></Choice>
  </Choices>
</View>""",
            ),
            # setup API returns this model version
            '12345',
            # task data
            [{'text': 'text A'}, {'text': 'text B'}],
            # annotations: there are no annotations
            [None, None],
            # predictions: 2 predictions, one from the new model version, second from old
            [
                {
                    'result': [
                        {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}}
                    ],
                    'score': 0.95,
                    'model_version': '12345',
                },
                {
                    'result': [
                        {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_B']}}
                    ],
                    'score': 0.59,
                    'model_version': '12345_old',
                },
            ],
            # prediction call count is 1 only for the task with old predictions
            1,
            # ground_truth stats (not checked by the test body)
            0,
            0,
            0,
        ),
        (
            # test name just for reference
            'All tasks has no predictions',
            # project config: contains actual model version
            dict(
                title='Test',
                model_version='12345',
                label_config="""
<View>
  <Text name="txt" value="$text"></Text>
  <Choices name="cls" toName="txt" choice="single">
    <Choice value="class_A"></Choice>
    <Choice value="class_B"></Choice>
  </Choices>
</View>""",
            ),
            # setup API returns this model version
            '12345',
            # task data
            [{'text': 'text A'}, {'text': 'text B'}],
            # annotations: there are no annotations
            [None, None],
            # there are no predictions yet
            [None, None],
            # prediction call count for all tasks without predictions
            2,
            # ground_truth stats (not checked by the test body)
            0,
            0,
            0,
        ),
        (
            # test name just for reference
            'Some tasks has no predictions, others are up-to-date',
            # project config: contains actual model version
            dict(
                title='Test',
                model_version='12345',
                label_config="""
                <View>
                <Text name="txt" value="$text"></Text>
                <Choices name="cls" toName="txt" choice="single">
                <Choice value="class_A"></Choice>
                <Choice value="class_B"></Choice>
                </Choices>
                </View>""",
            ),
            # setup API returns this model version
            '12345',
            # task data
            [{'text': 'text A'}, {'text': 'text B'}],
            # annotations: there are no annotations
            [None, None],
            # there is only one prediction (since job has finished before processing all tasks)
            [
                {
                    'result': [
                        {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}}
                    ],
                    'score': 0.95,
                    'model_version': '12345',
                },
                None,
            ],
            # prediction call count for all tasks without predictions
            1,
            # ground_truth stats (not checked by the test body)
            0,
            0,
            0,
        ),
        (
            # test name just for reference
            'Some tasks has no predictions, others are up-to-date, labeled task contains ground_truth',
            # project config: contains actual model version
            dict(
                title='Test',
                model_version='12345',
                label_config="""
        <View>
        <Text name="txt" value="$text"></Text>
        <Choices name="cls" toName="txt" choice="single">
        <Choice value="class_A"></Choice>
        <Choice value="class_B"></Choice>
        </Choices>
        </View>""",
            ),
            # setup API returns this model version
            '12345',
            # task data
            [{'text': 'text A'}, {'text': 'text B'}],
            # annotations: only the second task has a ground_truth annotation
            [
                None,
                {
                    'result': [
                        {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}}
                    ],
                    'ground_truth': True,
                },
            ],
            # there is only one prediction (since job has finished before processing all tasks)
            [
                {
                    'result': [
                        {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}}
                    ],
                    'score': 0.95,
                    'model_version': '12345',
                },
                None,
            ],
            # prediction call count for all tasks without predictions
            1,
            # ground_truth stats (not checked by the test body)
            1,
            1,
            1,
        ),
        (
            # test name just for reference
            'Some tasks has no predictions, others are outdated',
            # project config: contains actual model version
            dict(
                title='Test',
                model_version='12345',
                label_config="""
        <View>
        <Text name="txt" value="$text"></Text>
        <Choices name="cls" toName="txt" choice="single">
        <Choice value="class_A"></Choice>
        <Choice value="class_B"></Choice>
        </Choices>
        </View>""",
            ),
            # setup API returns this model version
            '12345',
            # task data
            [{'text': 'text A'}, {'text': 'text B'}],
            # annotations: there are no annotations
            [None, None],
            # there is only one prediction, and it is from the old model version
            [
                {
                    'result': [
                        {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}}
                    ],
                    'score': 0.95,
                    'model_version': '12345_old',
                },
                None,
            ],
            # prediction call count for all tasks without up-to-date predictions
            2,
            # ground_truth stats (not checked by the test body)
            0,
            0,
            0,
        ),
        (
            # test name just for reference
            'Some tasks has no predictions, others are outdated, project.model_version is outdated',
            # project config: contains old model version
            dict(
                title='Test',
                model_version='12345_old',
                label_config="""
    <View>
    <Text name="txt" value="$text"></Text>
    <Choices name="cls" toName="txt" choice="single">
    <Choice value="class_A"></Choice>
    <Choice value="class_B"></Choice>
    </Choices>
    </View>""",
            ),
            # setup API returns this model version
            '12345',
            # task data
            [{'text': 'text A'}, {'text': 'text B'}],
            # annotations: there are no annotations
            [None, None],
            # there is only one prediction, and it is from the old model version
            [
                {
                    'result': [
                        {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}}
                    ],
                    'score': 0.95,
                    'model_version': '12345_old',
                },
                None,
            ],
            # prediction call count for all tasks without up-to-date predictions
            2,
            # ground_truth stats (not checked by the test body)
            0,
            0,
            0,
        ),
        (
            # test name just for reference
            'Some tasks has no predictions, others are outdated, others are up-to-date',
            # project config: contains old model version
            dict(
                title='Test',
                model_version='12345_old',
                label_config="""
<View>
<Text name="txt" value="$text"></Text>
<Choices name="cls" toName="txt" choice="single">
<Choice value="class_A"></Choice>
<Choice value="class_B"></Choice>
</Choices>
</View>""",
            ),
            # setup API returns this model version
            '12345',
            # task data
            [{'text': 'text A'}, {'text': 'text A'}, {'text': 'text B'}],
            # annotations: there are no annotations
            [None, None, None],
            # predictions: one outdated, one up-to-date, and one task without any prediction
            [
                {
                    'result': [
                        {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}}
                    ],
                    'score': 0.95,
                    'model_version': '12345_old',
                },
                {
                    'result': [
                        {'from_name': 'cls', 'to_name': 'txt', 'type': 'choices', 'value': {'choices': ['class_A']}}
                    ],
                    'score': 0.95,
                    'model_version': '12345',
                },
                None,
            ],
            # prediction call count for all tasks without up-to-date predictions
            2,
            # ground_truth stats (not checked by the test body)
            0,
            0,
            0,
        ),
    ],
)
@pytest.mark.django_db
def test_predictions_with_partially_predicted_tasks(
    business_client,
    test_name,
    setup_returns_model_version,
    project_config,
    tasks,
    annotations,
    input_predictions,
    prediction_call_count,
    num_project_stats,
    num_ground_truth_in_stats,
    num_ground_truth_fit_predictions,
    mocker,
):
    """Check which tasks get (re)predicted when some predictions already exist.

    Each task is created with an optional annotation and an optional
    pre-existing prediction, and the ML backend's stored model version is set
    to ``project_config['model_version']``. With ``/setup`` and ``/predict``
    mocked, predictions are triggered and the test asserts that:

    * the number of POSTs to ``/predict`` equals ``prediction_call_count``;
    * the number of predictions tagged with ``setup_returns_model_version``
      equals the number of tasks;
    * the backend's stored model version equals ``setup_returns_model_version``.

    ``test_name``, ``num_project_stats``, ``num_ground_truth_in_stats`` and
    ``num_ground_truth_fit_predictions`` come from the parametrization but are
    not referenced by the assertions below.
    """
    project = make_project(project_config, business_client.user)
    # presumably make_project connects this backend to the project - confirm in .utils
    ml_backend = MLBackend.objects.get(url='http://localhost:8999')
    # start from the model version declared in the project config
    ml_backend.model_version = project_config['model_version']
    ml_backend.save()
    # seed tasks; a None annotation/prediction means "do not create one for this task"
    for task, annotation, prediction in zip(tasks, annotations, input_predictions):
        task_obj = Task.objects.create(project=project, data=task)
        if annotation is not None:
            Annotation.objects.create(task=task_obj, **annotation)
        if prediction is not None:
            Prediction.objects.create(task=task_obj, project=task_obj.project, **prediction)

    # run prediction
    with requests_mock.Mocker() as m:
        m.register_uri(
            'POST', 'http://localhost:8999/setup', text=json.dumps({'model_version': setup_returns_model_version})
        )
        # every /predict call answers with the same single class_A result
        m.register_uri(
            'POST',
            'http://localhost:8999/predict',
            text=json.dumps(
                {
                    'results': [
                        {
                            'result': [
                                {
                                    'from_name': 'cls',
                                    'to_name': 'txt',
                                    'type': 'choices',
                                    'value': {'choices': ['class_A']},
                                }
                            ],
                            'score': 1,
                        }
                    ],
                    'model_version': setup_returns_model_version,
                }
            ),
        )

        r = run_task_predictions(business_client, project, mocker)
        assert r.status_code == 200
        # count only the requests that hit the /predict endpoint
        assert len(list(filter(lambda h: h.url.endswith('predict'), m.request_history))) == prediction_call_count

        assert Prediction.objects.filter(project=project.id, model_version=setup_returns_model_version).count() == len(
            tasks
        )
        assert MLBackend.objects.get(url='http://localhost:8999').model_version == setup_returns_model_version


@pytest.mark.django_db
def test_interactive_annotating(business_client, configured_project):
    """Interactive annotating returns the mocked backend result under ``data``.

    The project's first ML backend is switched to interactive mode, its
    ``/predict`` endpoint is mocked, and the ``interactive-annotating`` API is
    called for the project's first task.
    """
    # enable interactive mode on the project's ML backend
    ml_backend = configured_project.ml_backends.first()
    ml_backend.is_interactive = True
    ml_backend.save()

    task = configured_project.tasks.first()
    # run prediction; real_http=True lets any non-registered request pass through
    with requests_mock.Mocker(real_http=True) as m:
        m.register_uri('POST', f'{ml_backend.url}/predict', json={'results': [{'x': 'x'}]}, status_code=200)

        r = business_client.post(
            f'/api/ml/{ml_backend.pk}/interactive-annotating',
            data=json.dumps(
                {
                    'task': task.id,
                    'context': {'y': 'y'},
                }
            ),
            content_type='application/json',
        )
        # The original assigned `r.status_code = 200`, which verified nothing.
        assert r.status_code == 200

        result = r.json()

        assert 'data' in result
        assert 'x' in result['data']
        assert result['data']['x'] == 'x'


@pytest.mark.django_db
def test_interactive_annotating_failing(business_client, configured_project):
    """Interactive annotating reports ``errors`` when the backend cannot answer.

    Two failure modes are exercised against the project's first ML backend:

    1. no mock is registered for the backend's ``/predict`` endpoint;
    2. ``/predict`` is mocked but returns a payload without a ``results`` key.

    In both cases the API response is expected to contain an ``errors`` key.
    """
    # enable interactive mode on the project's ML backend
    ml_backend = configured_project.ml_backends.first()
    ml_backend.is_interactive = True
    ml_backend.save()

    task = configured_project.tasks.first()
    request_body = json.dumps(
        {
            'task': task.id,
            'context': {'y': 'y'},
        }
    )
    endpoint = f'/api/ml/{ml_backend.pk}/interactive-annotating'

    # NO MOCK REGISTERED FOR THE ML BACKEND
    r = business_client.post(endpoint, data=request_body, content_type='application/json')
    # The original assigned `r.status_code = 200`, which verified nothing.
    # NOTE(review): assumes the API answers 200 and reports failures in the body - confirm against the view.
    assert r.status_code == 200

    result = r.json()

    assert 'errors' in result

    # BAD ML RESPONSE
    with requests_mock.Mocker(real_http=True) as m:
        m.register_uri('POST', f'{ml_backend.url}/predict', json={'kebab': [[['eat']]]}, status_code=200)

        r = business_client.post(endpoint, data=request_body, content_type='application/json')
        assert r.status_code == 200

        result = r.json()

    assert 'errors' in result


@pytest.mark.django_db
def test_interactive_annotating_with_drafts(business_client, configured_project):
    """
    Test interactive annotating with drafts

    Two drafts by two different users are stored for the same task; the
    payload sent to the backend's ``/predict`` endpoint must carry exactly one
    draft for that task (presumably the requesting user's - confirm against
    the backend request builder).

    :param business_client: authenticated test client
    :param configured_project: project fixture providing an ML backend and tasks
    :return:
    """
    # enable interactive mode on the project's ML backend
    ml_backend = configured_project.ml_backends.first()
    ml_backend.is_interactive = True
    ml_backend.save()

    # relies on the fixtures having created at least two users
    users = list(User.objects.all())

    task = configured_project.tasks.first()
    AnnotationDraft.objects.create(task=task, user=users[0], result={}, lead_time=1)
    AnnotationDraft.objects.create(task=task, user=users[1], result={}, lead_time=2)
    # run prediction
    with requests_mock.Mocker(real_http=True) as m:
        m.register_uri('POST', f'{ml_backend.url}/predict', json={'results': [{'x': 'x'}]}, status_code=200)

        r = business_client.post(
            f'/api/ml/{ml_backend.pk}/interactive-annotating',
            data=json.dumps(
                {
                    'task': task.id,
                    'context': {'y': 'y'},
                }
            ),
            content_type='application/json',
        )
        # The original assigned `r.status_code = 200`, which verified nothing.
        assert r.status_code == 200

        result = r.json()

        assert 'data' in result
        assert 'x' in result['data']
        assert result['data']['x'] == 'x'

        # inspect what was actually sent to the backend's /predict endpoint
        predict_requests = [req for req in m.request_history if 'predict' in req.path]
        assert predict_requests, 'ML backend /predict endpoint was never called'
        history = predict_requests[0]
        assert history.text

        js = json.loads(history.text)

        assert len(js['tasks'][0]['drafts']) == 1


@pytest.mark.django_db
def test_predictions_meta(business_client, configured_project):
    """PredictionMeta must reference exactly one of Prediction / FailedPrediction.

    Creating a meta record for either a prediction or a failed prediction
    works and leaves the untouched fields as ``None``; creating one with both
    references, or with neither, raises.
    """
    from tasks.models import FailedPrediction, Prediction, PredictionMeta

    first_task = configured_project.tasks.first()
    # fields shared by the successful and the failed prediction
    shared_fields = {'task': first_task, 'project': first_task.project, 'model_version': '12345'}

    # a regular prediction
    ok_prediction = Prediction.objects.create(
        result={
            'result': [
                {'from_name': 'text_class', 'to_name': 'text', 'type': 'choices', 'value': {'choices': ['class_A']}}
            ]
        },
        score=0.95,
        **shared_fields,
    )

    # a failed prediction
    broken_prediction = FailedPrediction.objects.create(message='error', **shared_fields)

    # meta attached to a Prediction: unset fields are read back as None
    created = PredictionMeta.objects.create(prediction=ok_prediction)
    reloaded = PredictionMeta.objects.get(id=created.id)
    assert reloaded.inference_time is None
    assert reloaded.failed_prediction is None

    # meta attached to a FailedPrediction: unset fields are read back as None
    created = PredictionMeta.objects.create(failed_prediction=broken_prediction)
    reloaded = PredictionMeta.objects.get(id=created.id)
    assert reloaded.total_cost is None
    assert reloaded.prediction is None

    # invalid combinations, in order: both references given, then no reference at all
    invalid_combinations = (
        {'prediction': ok_prediction, 'failed_prediction': broken_prediction},
        {},
    )
    for invalid_kwargs in invalid_combinations:
        with pytest.raises(Exception):
            PredictionMeta.objects.create(**invalid_kwargs)
