# Generated by Django 3.2.25 on 2024-09-12 21:59

from django.db import migrations, models, transaction
import django.db.models.deletion
import django_migration_linter as linter
from core.redis import start_job_async_or_sync
from ml_models.models import ThirdPartyModelVersion
from ml_model_providers.models import ModelProviderConnection, ModelProviders


def _fill_model_version_model_provider_connection(db_alias: str):
    """Backfill ``model_provider_connection`` on existing third-party model versions.

    For every OpenAI and Azure OpenAI ``ThirdPartyModelVersion``, look up a
    ``ModelProviderConnection`` with the same organization and provider and
    store its id on the model version. For Azure OpenAI the connection's
    ``deployment_name`` must also equal the version's ``provider_model_id``.
    Versions with no matching connection have ``None`` written.

    When several connections match, ``first()`` is used so that an otherwise
    unordered queryset is ordered by primary key; the chosen connection is
    then repeatable instead of depending on database row order.

    Args:
        db_alias: Alias of the database to read from and write to.
    """
    for provider in (ModelProviders.OPENAI, ModelProviders.AZURE_OPENAI):
        # Loop-invariant: only Azure OpenAI connections are matched by deployment name.
        match_deployment = provider == ModelProviders.AZURE_OPENAI
        model_versions = (
            ThirdPartyModelVersion.objects.using(db_alias)
            .filter(provider=provider)
            .values('id', 'organization_id', 'provider_model_id')
        )
        for model_version in model_versions:
            lookup = {
                'organization_id': model_version['organization_id'],
                'provider': provider,
            }
            if match_deployment:
                lookup['deployment_name'] = model_version['provider_model_id']

            # first() returns the id itself (flat values_list) or None when nothing matches.
            connection_id = (
                ModelProviderConnection.objects.using(db_alias)
                .filter(**lookup)
                .values_list('id', flat=True)
                .first()
            )
            ThirdPartyModelVersion.objects.using(db_alias).filter(
                id=model_version['id'],
            ).update(model_provider_connection_id=connection_id)

def forwards(apps, schema_editor):
    """Hand the connection backfill to ``start_job_async_or_sync``.

    The job receives the alias of the database this migration is being
    applied to, taken from ``schema_editor.connection``. ``apps`` is unused.
    """
    start_job_async_or_sync(
        _fill_model_version_model_provider_connection,
        db_alias=schema_editor.connection.alias,
    )


def backwards(apps, schema_editor):
    """No-op reverse step: nothing is undone and both arguments are ignored."""
    return None


class Migration(migrations.Migration):
    """Add ``ThirdPartyModelVersion.model_provider_connection`` and backfill it."""

    # Do not wrap the whole migration in a single transaction.
    atomic = False

    dependencies = [
        ('ml_model_providers', '0003_modelproviderconnection_cached_available_models'),
        ('ml_models', '0010_modelinterface_skill_name'),
    ]

    operations = [
        # Marker operation from django_migration_linter; presumably excludes
        # this migration from lint checks.
        linter.IgnoreMigration(),
        # Schema step: nullable FK to the provider connection, set to NULL
        # when the referenced connection is deleted.
        migrations.AddField(
            model_name='thirdpartymodelversion',
            name='model_provider_connection',
            field=models.ForeignKey(
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                related_name='model_versions',
                to='ml_model_providers.modelproviderconnection',
            ),
        ),
        # Data step: runs after the column exists; its reverse function is a no-op.
        migrations.RunPython(forwards, backwards),
    ]
