# This file was auto-generated by Fern from our API Definition.

import datetime as dt
import typing
from json.decoder import JSONDecodeError

from ...core.api_error import ApiError
from ...core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ...core.http_response import AsyncHttpResponse, HttpResponse
from ...core.jsonable_encoder import jsonable_encoder
from ...core.request_options import RequestOptions
from ...core.unchecked_base_model import construct_type
from ...types.cancel_model_run_response import CancelModelRunResponse
from ...types.model_run import ModelRun
from ...types.project_subset_enum import ProjectSubsetEnum
from .types.list_runs_request_project_subset import ListRunsRequestProjectSubset

# Sentinel (the built-in Ellipsis object) used as the default for optional
# request-body parameters and handed to the HTTP client as its `omit=` marker.
OMIT = typing.cast(typing.Any, Ellipsis)


class RawRunsClient:
    """
    Synchronous low-level client for the inference-run endpoints of a prompt version.

    Every method sends a single request through the wrapped HTTP client and hands
    back the parsed payload together with the underlying response object.
    """

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        # All requests are issued via ``client_wrapper.httpx_client``.
        self._client_wrapper = client_wrapper

    def list(
        self,
        prompt_id: int,
        version_id: int,
        *,
        ordering: typing.Optional[str] = None,
        parent_model: typing.Optional[int] = None,
        project: typing.Optional[int] = None,
        project_subset: typing.Optional[ListRunsRequestProjectSubset] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> HttpResponse[typing.List[ModelRun]]:
        """
        Get information (status, metadata, etc) about the existing inference runs
        of a prompt version.

        Label Studio Enterprise only: this endpoint is not available in Label Studio
        Community Edition (see https://humansignal.com/goenterprise).

        Issues ``GET api/prompts/{prompt_id}/versions/{version_id}/inference-runs``.

        Parameters
        ----------
        prompt_id : int
            Inserted into the request path.

        version_id : int
            Inserted into the request path.

        ordering : typing.Optional[str]
            Which field to use when ordering the results.

        parent_model : typing.Optional[int]
            The ID of the parent model for this Inference Run

        project : typing.Optional[int]
            The ID of the project this Inference Run makes predictions on

        project_subset : typing.Optional[ListRunsRequestProjectSubset]
            Defines which tasks are operated on (e.g. HasGT will only operate on tasks
            with a ground truth annotation, but All will operate on all records)

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[typing.List[ModelRun]]
            The raw response paired with the decoded list of runs.

        Raises
        ------
        ApiError
            If the status code is outside 2xx (body is the decoded JSON), or if the
            response body is not valid JSON (body is the raw response text).
        """
        client = self._client_wrapper.httpx_client
        path = f"api/prompts/{jsonable_encoder(prompt_id)}/versions/{jsonable_encoder(version_id)}/inference-runs"
        query = {
            "ordering": ordering,
            "parent_model": parent_model,
            "project": project,
            "project_subset": project_subset,
        }
        reply = client.request(
            path,
            method="GET",
            params=query,
            request_options=request_options,
        )
        try:
            if not (200 <= reply.status_code < 300):
                # Error path: decode the body so it can be attached to the ApiError below.
                failure_body = reply.json()
            else:
                parsed = construct_type(
                    type_=typing.List[ModelRun],  # type: ignore
                    object_=reply.json(),
                )
                return HttpResponse(response=reply, data=typing.cast(typing.List[ModelRun], parsed))
        except JSONDecodeError:
            # Body was not JSON; surface the raw text instead.
            raise ApiError(
                status_code=reply.status_code,
                headers=dict(reply.headers),
                body=reply.text,
            )
        raise ApiError(
            status_code=reply.status_code,
            headers=dict(reply.headers),
            body=failure_body,
        )

    def create(
        self,
        prompt_id: int,
        version_id: int,
        *,
        project: int,
        job_id: typing.Optional[str] = OMIT,
        only_missing_predictions: typing.Optional[bool] = OMIT,
        organization: typing.Optional[int] = OMIT,
        predictions_updated_at: typing.Optional[dt.datetime] = OMIT,
        project_subset: typing.Optional[ProjectSubsetEnum] = OMIT,
        total_correct_predictions: typing.Optional[int] = OMIT,
        total_predictions: typing.Optional[int] = OMIT,
        total_tasks: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> HttpResponse[ModelRun]:
        """
        Run a prompt inference.

        Label Studio Enterprise only: this endpoint is not available in Label Studio
        Community Edition (see https://humansignal.com/goenterprise).

        Issues ``POST api/prompts/{prompt_id}/versions/{version_id}/inference-runs``
        with a JSON body built from the keyword arguments. The ``OMIT`` sentinel is
        passed to the HTTP client as ``omit=``; presumably arguments left at that
        default are dropped from the body by the client (not visible in this module).

        Parameters
        ----------
        prompt_id : int
            Inserted into the request path.

        version_id : int
            Inserted into the request path.

        project : int
            Sent as the ``project`` field of the body.

        job_id : typing.Optional[str]
            Job ID for inference job for a ModelRun e.g. Adala job ID

        only_missing_predictions : typing.Optional[bool]
            When true, only tasks without successful predictions for this prompt
            version are submitted for inference.

        organization : typing.Optional[int]
            Sent as the ``organization`` field of the body.

        predictions_updated_at : typing.Optional[dt.datetime]
            Sent as the ``predictions_updated_at`` field of the body.

        project_subset : typing.Optional[ProjectSubsetEnum]
            Sent as the ``project_subset`` field of the body.

        total_correct_predictions : typing.Optional[int]
            Sent as the ``total_correct_predictions`` field of the body.

        total_predictions : typing.Optional[int]
            Sent as the ``total_predictions`` field of the body.

        total_tasks : typing.Optional[int]
            Sent as the ``total_tasks`` field of the body.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[ModelRun]
            The raw response paired with the decoded run.

        Raises
        ------
        ApiError
            If the status code is outside 2xx (body is the decoded JSON), or if the
            response body is not valid JSON (body is the raw response text).
        """
        client = self._client_wrapper.httpx_client
        path = f"api/prompts/{jsonable_encoder(prompt_id)}/versions/{jsonable_encoder(version_id)}/inference-runs"
        payload = {
            "job_id": job_id,
            "only_missing_predictions": only_missing_predictions,
            "organization": organization,
            "predictions_updated_at": predictions_updated_at,
            "project": project,
            "project_subset": project_subset,
            "total_correct_predictions": total_correct_predictions,
            "total_predictions": total_predictions,
            "total_tasks": total_tasks,
        }
        reply = client.request(
            path,
            method="POST",
            json=payload,
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
        try:
            if not (200 <= reply.status_code < 300):
                # Error path: decode the body so it can be attached to the ApiError below.
                failure_body = reply.json()
            else:
                parsed = construct_type(
                    type_=ModelRun,  # type: ignore
                    object_=reply.json(),
                )
                return HttpResponse(response=reply, data=typing.cast(ModelRun, parsed))
        except JSONDecodeError:
            # Body was not JSON; surface the raw text instead.
            raise ApiError(
                status_code=reply.status_code,
                headers=dict(reply.headers),
                body=reply.text,
            )
        raise ApiError(
            status_code=reply.status_code,
            headers=dict(reply.headers),
            body=failure_body,
        )

    def cancel(
        self,
        prompt_id: int,
        version_id: int,
        inference_run_id: int,
        *,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> HttpResponse[CancelModelRunResponse]:
        """
        Cancel an inference run.

        Label Studio Enterprise only: this endpoint is not available in Label Studio
        Community Edition (see https://humansignal.com/goenterprise).

        Issues ``POST api/prompts/{prompt_id}/versions/{version_id}/inference-runs/{inference_run_id}/cancel``
        without a request body.

        Parameters
        ----------
        prompt_id : int
            Inserted into the request path.

        version_id : int
            Inserted into the request path.

        inference_run_id : int
            Inserted into the request path; identifies the run to cancel.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        HttpResponse[CancelModelRunResponse]
            The raw response paired with the decoded cancellation result.

        Raises
        ------
        ApiError
            If the status code is outside 2xx (body is the decoded JSON), or if the
            response body is not valid JSON (body is the raw response text).
        """
        client = self._client_wrapper.httpx_client
        path = f"api/prompts/{jsonable_encoder(prompt_id)}/versions/{jsonable_encoder(version_id)}/inference-runs/{jsonable_encoder(inference_run_id)}/cancel"
        reply = client.request(
            path,
            method="POST",
            request_options=request_options,
        )
        try:
            if not (200 <= reply.status_code < 300):
                # Error path: decode the body so it can be attached to the ApiError below.
                failure_body = reply.json()
            else:
                parsed = construct_type(
                    type_=CancelModelRunResponse,  # type: ignore
                    object_=reply.json(),
                )
                return HttpResponse(response=reply, data=typing.cast(CancelModelRunResponse, parsed))
        except JSONDecodeError:
            # Body was not JSON; surface the raw text instead.
            raise ApiError(
                status_code=reply.status_code,
                headers=dict(reply.headers),
                body=reply.text,
            )
        raise ApiError(
            status_code=reply.status_code,
            headers=dict(reply.headers),
            body=failure_body,
        )


class AsyncRawRunsClient:
    """
    Asynchronous low-level client for the inference-run endpoints of a prompt version.

    Every coroutine awaits a single request on the wrapped HTTP client and hands
    back the parsed payload together with the underlying response object.
    """

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        # All requests are issued via ``client_wrapper.httpx_client``.
        self._client_wrapper = client_wrapper

    async def list(
        self,
        prompt_id: int,
        version_id: int,
        *,
        ordering: typing.Optional[str] = None,
        parent_model: typing.Optional[int] = None,
        project: typing.Optional[int] = None,
        project_subset: typing.Optional[ListRunsRequestProjectSubset] = None,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AsyncHttpResponse[typing.List[ModelRun]]:
        """
        Get information (status, metadata, etc) about the existing inference runs
        of a prompt version.

        Label Studio Enterprise only: this endpoint is not available in Label Studio
        Community Edition (see https://humansignal.com/goenterprise).

        Issues ``GET api/prompts/{prompt_id}/versions/{version_id}/inference-runs``.

        Parameters
        ----------
        prompt_id : int
            Inserted into the request path.

        version_id : int
            Inserted into the request path.

        ordering : typing.Optional[str]
            Which field to use when ordering the results.

        parent_model : typing.Optional[int]
            The ID of the parent model for this Inference Run

        project : typing.Optional[int]
            The ID of the project this Inference Run makes predictions on

        project_subset : typing.Optional[ListRunsRequestProjectSubset]
            Defines which tasks are operated on (e.g. HasGT will only operate on tasks
            with a ground truth annotation, but All will operate on all records)

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AsyncHttpResponse[typing.List[ModelRun]]
            The raw response paired with the decoded list of runs.

        Raises
        ------
        ApiError
            If the status code is outside 2xx (body is the decoded JSON), or if the
            response body is not valid JSON (body is the raw response text).
        """
        client = self._client_wrapper.httpx_client
        path = f"api/prompts/{jsonable_encoder(prompt_id)}/versions/{jsonable_encoder(version_id)}/inference-runs"
        query = {
            "ordering": ordering,
            "parent_model": parent_model,
            "project": project,
            "project_subset": project_subset,
        }
        reply = await client.request(
            path,
            method="GET",
            params=query,
            request_options=request_options,
        )
        try:
            if not (200 <= reply.status_code < 300):
                # Error path: decode the body so it can be attached to the ApiError below.
                failure_body = reply.json()
            else:
                parsed = construct_type(
                    type_=typing.List[ModelRun],  # type: ignore
                    object_=reply.json(),
                )
                return AsyncHttpResponse(response=reply, data=typing.cast(typing.List[ModelRun], parsed))
        except JSONDecodeError:
            # Body was not JSON; surface the raw text instead.
            raise ApiError(
                status_code=reply.status_code,
                headers=dict(reply.headers),
                body=reply.text,
            )
        raise ApiError(
            status_code=reply.status_code,
            headers=dict(reply.headers),
            body=failure_body,
        )

    async def create(
        self,
        prompt_id: int,
        version_id: int,
        *,
        project: int,
        job_id: typing.Optional[str] = OMIT,
        only_missing_predictions: typing.Optional[bool] = OMIT,
        organization: typing.Optional[int] = OMIT,
        predictions_updated_at: typing.Optional[dt.datetime] = OMIT,
        project_subset: typing.Optional[ProjectSubsetEnum] = OMIT,
        total_correct_predictions: typing.Optional[int] = OMIT,
        total_predictions: typing.Optional[int] = OMIT,
        total_tasks: typing.Optional[int] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AsyncHttpResponse[ModelRun]:
        """
        Run a prompt inference.

        Label Studio Enterprise only: this endpoint is not available in Label Studio
        Community Edition (see https://humansignal.com/goenterprise).

        Issues ``POST api/prompts/{prompt_id}/versions/{version_id}/inference-runs``
        with a JSON body built from the keyword arguments. The ``OMIT`` sentinel is
        passed to the HTTP client as ``omit=``; presumably arguments left at that
        default are dropped from the body by the client (not visible in this module).

        Parameters
        ----------
        prompt_id : int
            Inserted into the request path.

        version_id : int
            Inserted into the request path.

        project : int
            Sent as the ``project`` field of the body.

        job_id : typing.Optional[str]
            Job ID for inference job for a ModelRun e.g. Adala job ID

        only_missing_predictions : typing.Optional[bool]
            When true, only tasks without successful predictions for this prompt
            version are submitted for inference.

        organization : typing.Optional[int]
            Sent as the ``organization`` field of the body.

        predictions_updated_at : typing.Optional[dt.datetime]
            Sent as the ``predictions_updated_at`` field of the body.

        project_subset : typing.Optional[ProjectSubsetEnum]
            Sent as the ``project_subset`` field of the body.

        total_correct_predictions : typing.Optional[int]
            Sent as the ``total_correct_predictions`` field of the body.

        total_predictions : typing.Optional[int]
            Sent as the ``total_predictions`` field of the body.

        total_tasks : typing.Optional[int]
            Sent as the ``total_tasks`` field of the body.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AsyncHttpResponse[ModelRun]
            The raw response paired with the decoded run.

        Raises
        ------
        ApiError
            If the status code is outside 2xx (body is the decoded JSON), or if the
            response body is not valid JSON (body is the raw response text).
        """
        client = self._client_wrapper.httpx_client
        path = f"api/prompts/{jsonable_encoder(prompt_id)}/versions/{jsonable_encoder(version_id)}/inference-runs"
        payload = {
            "job_id": job_id,
            "only_missing_predictions": only_missing_predictions,
            "organization": organization,
            "predictions_updated_at": predictions_updated_at,
            "project": project,
            "project_subset": project_subset,
            "total_correct_predictions": total_correct_predictions,
            "total_predictions": total_predictions,
            "total_tasks": total_tasks,
        }
        reply = await client.request(
            path,
            method="POST",
            json=payload,
            headers={
                "content-type": "application/json",
            },
            request_options=request_options,
            omit=OMIT,
        )
        try:
            if not (200 <= reply.status_code < 300):
                # Error path: decode the body so it can be attached to the ApiError below.
                failure_body = reply.json()
            else:
                parsed = construct_type(
                    type_=ModelRun,  # type: ignore
                    object_=reply.json(),
                )
                return AsyncHttpResponse(response=reply, data=typing.cast(ModelRun, parsed))
        except JSONDecodeError:
            # Body was not JSON; surface the raw text instead.
            raise ApiError(
                status_code=reply.status_code,
                headers=dict(reply.headers),
                body=reply.text,
            )
        raise ApiError(
            status_code=reply.status_code,
            headers=dict(reply.headers),
            body=failure_body,
        )

    async def cancel(
        self,
        prompt_id: int,
        version_id: int,
        inference_run_id: int,
        *,
        request_options: typing.Optional[RequestOptions] = None,
    ) -> AsyncHttpResponse[CancelModelRunResponse]:
        """
        Cancel an inference run.

        Label Studio Enterprise only: this endpoint is not available in Label Studio
        Community Edition (see https://humansignal.com/goenterprise).

        Issues ``POST api/prompts/{prompt_id}/versions/{version_id}/inference-runs/{inference_run_id}/cancel``
        without a request body.

        Parameters
        ----------
        prompt_id : int
            Inserted into the request path.

        version_id : int
            Inserted into the request path.

        inference_run_id : int
            Inserted into the request path; identifies the run to cancel.

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AsyncHttpResponse[CancelModelRunResponse]
            The raw response paired with the decoded cancellation result.

        Raises
        ------
        ApiError
            If the status code is outside 2xx (body is the decoded JSON), or if the
            response body is not valid JSON (body is the raw response text).
        """
        client = self._client_wrapper.httpx_client
        path = f"api/prompts/{jsonable_encoder(prompt_id)}/versions/{jsonable_encoder(version_id)}/inference-runs/{jsonable_encoder(inference_run_id)}/cancel"
        reply = await client.request(
            path,
            method="POST",
            request_options=request_options,
        )
        try:
            if not (200 <= reply.status_code < 300):
                # Error path: decode the body so it can be attached to the ApiError below.
                failure_body = reply.json()
            else:
                parsed = construct_type(
                    type_=CancelModelRunResponse,  # type: ignore
                    object_=reply.json(),
                )
                return AsyncHttpResponse(response=reply, data=typing.cast(CancelModelRunResponse, parsed))
        except JSONDecodeError:
            # Body was not JSON; surface the raw text instead.
            raise ApiError(
                status_code=reply.status_code,
                headers=dict(reply.headers),
                body=reply.text,
            )
        raise ApiError(
            status_code=reply.status_code,
            headers=dict(reply.headers),
            body=failure_body,
        )
