import pytest
from label_studio_sdk.client import LabelStudio
from label_studio_sdk.label_interface import LabelInterface


@pytest.mark.django_db
def test_batch_predictions_single_prediction_per_task(django_live_url, business_client, ml_backend_for_test_predict):
    """Batch-predict tasks with a single-prediction ML backend, then convert predictions to annotations.

    Flow:
      1. Create a project with a Text + single-choice Choices config and import three tasks.
      2. Register the ML backend at port 9092 titled 'ModelSingle' and check it was stored.
      3. Run the 'retrieve_tasks_predictions' action on every task except the second one and
         expect exactly two predictions, each choosing 'label_A' with model_version 'ModelSingle'.
      4. Run 'predictions_to_annotations' on the two predicted tasks plus the unpredicted one
         (emulating a user error). Then check that only the predicted tasks gained exactly one
         annotation and still have their single prediction (score 0.1).

    NOTE(review): the ``ml_backend_for_test_predict`` fixture presumably mocks the backend URL
    used here; confirm against its definition.
    """
    backend_url = 'http://test.ml.backend.for.sdk.com:9092'
    client = LabelStudio(base_url=django_live_url, api_key=business_client.api_key)

    choices = (
        ('Choice', {'value': 'label_A'}, ()),
        ('Choice', {'value': 'label_B'}, ()),
    )
    config = LabelInterface.create(
        {
            'text': ('Text', {'name': 'text', 'value': '$text'}, ()),
            'label': ('Choices', {'name': 'label', 'toName': 'text', 'choice': 'single'}, choices),
        }
    )
    project = client.projects.create(title='New Project', label_config=config)

    texts = ['test 1', 'test 2', 'test 3']
    client.projects.import_tasks(project.id, request=[{'data': {'text': text}} for text in texts])

    tasks = list(client.tasks.list(project=project.id))
    assert len(tasks) == 3
    # The second task is deliberately left out of the prediction run.
    skipped_id = tasks[1].id

    # Register a backend that returns one prediction per task.
    client.ml.create(url=backend_url, project=project.id, title='ModelSingle')

    backends = client.ml.list(project=project.id)
    assert len(backends) == 1
    registered = backends[0]
    assert registered.title == 'ModelSingle'
    assert registered.url == backend_url

    # Predict every task except the skipped one.
    client.actions.create(
        id='retrieve_tasks_predictions',
        project=project.id,
        selected_items={'all': True, 'excluded': [skipped_id]},
    )

    predictions = client.predictions.list(project=project.id)
    assert len(predictions) == 2

    first, second = predictions[0], predictions[1]
    for prediction in (first, second):
        assert prediction.result[0]['value']['choices'][0] == 'label_A'
        assert prediction.model_version == 'ModelSingle'

    # Convert predictions to annotations. The skipped task has no prediction,
    # which emulates a user selecting a task that cannot be converted.
    client.actions.create(
        id='predictions_to_annotations',
        project=project.id,
        selected_items={'all': False, 'included': [first.task, second.task, skipped_id]},
    )

    for task in client.tasks.list(project=project.id, fields='all'):
        if task.id == skipped_id:
            # The excluded task must have neither annotations nor predictions.
            assert not task.annotations
            assert not task.predictions
            continue

        assert len(task.annotations) == 1
        assert task.annotations[0]['result'][0]['value']['choices'][0] == 'label_A'

        assert len(task.predictions) == 1
        kept = task.predictions[0]
        assert kept.result[0]['value']['choices'][0] == 'label_A'
        assert kept.model_version == 'ModelSingle'
        assert kept.score == 0.1


@pytest.mark.django_db
def test_batch_predictions_multiple_predictions_per_task(
    django_live_url, business_client, ml_backend_for_test_predict
):
    """Batch-predict two of three tasks with a backend that returns two predictions per task.

    Flow:
      1. Create a project with a Text + single-choice Choices config and import three tasks.
      2. Register the ML backend at port 9093 titled 'ModelMultiple'.
      3. Run 'retrieve_tasks_predictions' on the first and third tasks only.
      4. Expect four predictions in total. Each predicted task should carry
         ('label_A', 'ModelA', 0.2) and ('label_B', 'ModelB', 0.3) in that order,
         and the second task should have none.

    NOTE(review): the ``ml_backend_for_test_predict`` fixture presumably mocks the backend URL
    used here; confirm against its definition.
    """
    ls = LabelStudio(base_url=django_live_url, api_key=business_client.api_key)
    li = LabelInterface(
        """
            <View>
              <Text name="text" value="$text"/>
              <Choices name="label" toName="text" choice="single">
                <Choice value="label_A"></Choice>
                <Choice value="label_B"></Choice>
              </Choices>
            </View>"""
    )
    p = ls.projects.create(
        title='New Project',
        label_config=li._config,
    )
    ls.projects.import_tasks(
        p.id,
        request=[
            {'data': {'text': 'test 1'}},
            {'data': {'text': 'test 2'}},
            {'data': {'text': 'test 3'}},
        ],
    )

    tasks = [task for task in ls.tasks.list(project=p.id)]
    assert len(tasks) == 3

    # setup ML backend with multiple predictions per task
    ls.ml.create(url='http://test.ml.backend.for.sdk.com:9093', project=p.id, title='ModelMultiple')

    # batch predict only the first and third tasks via actions
    ls.actions.create(
        id='retrieve_tasks_predictions',
        project=p.id,
        selected_items={'all': False, 'included': [tasks[0].id, tasks[2].id]},
    )

    # get all predictions in project
    predictions = ls.predictions.list(project=p.id)

    # two predictions for each of the two predicted tasks
    assert len(predictions) == 4

    seen = 0
    for task in ls.tasks.list(project=p.id, fields='all'):
        seen += 1
        if task.id == tasks[1].id:
            assert not task.predictions
        else:
            assert len(task.predictions) == 2

            for i, prediction in enumerate(task.predictions):
                suffix = 'AB'[i]
                assert prediction.result[0]['value']['choices'][0] == f'label_{suffix}'
                assert prediction.model_version == f'Model{suffix}'
                # The parentheses are required. Without them the statement parsed as
                # `assert (score == 0.2) if i == 0 else 0.3`. For i == 1 that asserted
                # the truthy constant 0.3, so the second score was never checked.
                assert prediction.score == (0.2 if i == 0 else 0.3)

    # Guard against the loop passing vacuously if the listing came back empty.
    assert seen == 3
