"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
"""
Main module with the bulk_update function.
"""
import itertools
from collections import defaultdict

from django.db import connections, models
from django.db.models.sql import UpdateQuery


def _get_db_type(field, connection):
    if isinstance(field, (models.PositiveSmallIntegerField, models.PositiveIntegerField)):
        return field.db_type(connection).split(' ', 1)[0]

    return field.db_type(connection)


def _as_sql(obj, field, query, compiler, connection):
    value = getattr(obj, field.attname)

    if hasattr(value, 'resolve_expression'):
        value = value.resolve_expression(query, allow_joins=False, for_save=True)
    else:
        value = field.get_db_prep_save(value, connection=connection)

    if hasattr(value, 'as_sql'):
        placeholder, value = compiler.compile(value)
        if isinstance(value, list):
            value = tuple(value)
    else:
        placeholder = '%s'

    return value, placeholder


def flatten(l, types=(list, float)):  # noqa: E741
    """
    Flat nested list of lists into a single list.
    """
    l = [item if isinstance(item, types) else [item] for item in l]  # noqa: E741
    return [item for sublist in l for item in sublist]


def grouper(iterable, size):
    # http://stackoverflow.com/a/8991553
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, size))
        if not chunk:
            return
        yield chunk


def validate_fields(meta, fields):

    fields = frozenset(fields)
    field_names = set()

    for field in meta.fields:
        if not field.primary_key:
            field_names.add(field.name)

            if field.name != field.attname:
                field_names.add(field.attname)

    non_model_fields = fields.difference(field_names)

    if non_model_fields:
        raise TypeError('These fields are not present in ' 'current meta: {}'.format(', '.join(non_model_fields)))


def get_fields(update_fields, exclude_fields, meta, obj=None):

    deferred_fields = set()

    if update_fields is not None:
        validate_fields(meta, update_fields)
    elif obj:
        deferred_fields = obj.get_deferred_fields()

    if exclude_fields is None:
        exclude_fields = set()
    else:
        exclude_fields = set(exclude_fields)
        validate_fields(meta, exclude_fields)

    exclude_fields |= deferred_fields

    fields = [
        field
        for field in meta.concrete_fields
        if (
            not field.primary_key
            and field.attname not in deferred_fields
            and field.attname not in exclude_fields
            and field.name not in exclude_fields
            and (update_fields is None or field.attname in update_fields or field.name in update_fields)
        )
    ]

    return fields


def bulk_update(
    objs, meta=None, update_fields=None, exclude_fields=None, using='default', batch_size=None, pk_field='pk'
):
    assert batch_size is None or batch_size > 0

    # force to retrieve objs from the DB at the beginning,
    # to avoid multiple subsequent queries
    objs = list(objs)
    if not objs:
        return
    batch_size = batch_size or len(objs)

    if meta:
        fields = get_fields(update_fields, exclude_fields, meta)
    else:
        meta = objs[0]._meta
        if update_fields is not None:
            fields = get_fields(update_fields, exclude_fields, meta, objs[0])
        else:
            fields = None

    if fields is not None and len(fields) == 0:
        return

    if pk_field == 'pk':
        pk_field = meta.get_field(meta.pk.name)
    else:
        pk_field = meta.get_field(pk_field)

    connection = connections[using]
    query = UpdateQuery(meta.model)
    compiler = query.get_compiler(connection=connection)

    template = '"{column}" = CAST(CASE "{pk_column}" {cases}ELSE "{column}" END AS {type})'

    case_template = 'WHEN %s THEN {} '

    lenpks = 0
    for objs_batch in grouper(objs, batch_size):

        pks = []
        parameters = defaultdict(list)
        placeholders = defaultdict(list)

        for obj in objs_batch:

            pk_value, _ = _as_sql(obj, pk_field, query, compiler, connection)
            pks.append(pk_value)

            loaded_fields = fields or get_fields(update_fields, exclude_fields, meta, obj)

            for field in loaded_fields:
                value, placeholder = _as_sql(obj, field, query, compiler, connection)
                parameters[field].extend(flatten([pk_value, value], types=tuple))
                placeholders[field].append(placeholder)

        values = ', '.join(
            template.format(
                column=field.column,
                pk_column=pk_field.column,
                cases=(case_template * len(placeholders[field])).format(*placeholders[field]),
                type=_get_db_type(field, connection=connection),
            )
            for field in parameters.keys()
        )

        parameters = flatten(parameters.values(), types=list)
        parameters.extend(pks)

        n_pks = len(pks)
        del pks

        dbtable = '"{}"'.format(meta.db_table)

        in_clause = '"{pk_column}" in ({pks})'.format(
            pk_column=pk_field.column,
            pks=', '.join(itertools.repeat('%s', n_pks)),
        )

        sql = 'UPDATE {dbtable} SET {values} WHERE {in_clause}'.format(  # nosec
            dbtable=dbtable,
            values=values,
            in_clause=in_clause,
        )
        del values

        lenpks += n_pks

        connection.cursor().execute(sql, parameters)

    return lenpks
