import sys
import json
import time
from functools import wraps
from collections.abc import Iterable

import sentry_sdk
from sentry_sdk import consts
from sentry_sdk.ai.monitoring import record_token_usage
from sentry_sdk.ai.utils import (
    set_data_normalized,
    normalize_message_roles,
    truncate_and_annotate_messages,
    truncate_and_annotate_embedding_inputs,
)
from sentry_sdk.ai._openai_completions_api import (
    _is_system_instruction as _is_system_instruction_completions,
    _get_system_instructions as _get_system_instructions_completions,
    _transform_system_instructions,
    _get_text_items,
)
from sentry_sdk.ai._openai_responses_api import (
    _is_system_instruction as _is_system_instruction_responses,
    _get_system_instructions as _get_system_instructions_responses,
)
from sentry_sdk.consts import SPANDATA
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing_utils import set_span_errored
from sentry_sdk.utils import (
    capture_internal_exceptions,
    event_from_exception,
    safe_serialize,
    reraise,
)

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import (
        Any,
        List,
        Optional,
        Callable,
        AsyncIterator,
        Iterator,
        Union,
        Iterable,
    )
    from sentry_sdk.tracing import Span
    from sentry_sdk._types import TextPart

    from openai.types.responses.response_usage import ResponseUsage
    from openai.types.responses import (
        ResponseInputParam,
        SequenceNotStr,
        ResponseStreamEvent,
    )
    from openai.types import CompletionUsage
    from openai import Omit

try:
    try:
        from openai import NotGiven
    except ImportError:
        NotGiven = None

    try:
        from openai import Omit
    except ImportError:
        Omit = None

    from openai.resources.chat.completions import Completions, AsyncCompletions
    from openai.resources import Embeddings, AsyncEmbeddings

    from openai import Stream, AsyncStream

    if TYPE_CHECKING:
        from openai.types.chat import (
            ChatCompletionMessageParam,
            ChatCompletionChunk,
        )
except ImportError:
    raise DidNotEnable("OpenAI not installed")

RESPONSES_API_ENABLED = True
try:
    # responses API support was introduced in v1.66.0
    from openai.resources.responses import Responses, AsyncResponses
    from openai.types.responses.response_completed_event import ResponseCompletedEvent
except ImportError:
    RESPONSES_API_ENABLED = False


class OpenAIIntegration(Integration):
    """Sentry integration for the ``openai`` Python client.

    ``setup_once`` replaces the ``create`` methods of the chat completions,
    embeddings and (on openai >= 1.66.0) responses resources with Sentry
    wrappers.
    """

    identifier = "openai"
    origin = f"auto.ai.{identifier}"

    def __init__(
        self: "OpenAIIntegration",
        include_prompts: bool = True,
        tiktoken_encoding_name: "Optional[str]" = None,
    ) -> None:
        """
        :param include_prompts: allow prompt and response content to be
            attached to spans (the rest of this module additionally checks
            ``send_default_pii``).
        :param tiktoken_encoding_name: name of the tiktoken encoding used by
            :meth:`count_tokens`. ``None`` disables local token counting and
            skips importing tiktoken entirely.
        """
        self.include_prompts = include_prompts
        self.tiktoken_encoding = None

        if tiktoken_encoding_name is None:
            return

        # Optional dependency: only imported when an encoding is requested.
        import tiktoken  # type: ignore

        self.tiktoken_encoding = tiktoken.get_encoding(tiktoken_encoding_name)

    @staticmethod
    def setup_once() -> None:
        """Patch the OpenAI resource classes' ``create`` methods in place."""
        patches = [
            (Completions, _wrap_chat_completion_create),
            (AsyncCompletions, _wrap_async_chat_completion_create),
            (Embeddings, _wrap_embeddings_create),
            (AsyncEmbeddings, _wrap_async_embeddings_create),
        ]
        # The Responses API classes only exist when their import succeeded.
        if RESPONSES_API_ENABLED:
            patches.append((Responses, _wrap_responses_create))
            patches.append((AsyncResponses, _wrap_async_responses_create))

        for resource_cls, make_wrapper in patches:
            resource_cls.create = make_wrapper(resource_cls.create)

    def count_tokens(self: "OpenAIIntegration", s: str) -> int:
        """Return the tiktoken token count of *s*.

        Returns 0 when no encoding was configured or the encoder raises.
        """
        encoding = self.tiktoken_encoding
        if encoding is not None:
            try:
                return len(encoding.encode_ordinary(s))
            except Exception:
                pass
        return 0


def _capture_exception(exc: "Any", manual_span_cleanup: bool = True) -> None:
    """Report *exc* to Sentry as an unhandled error with mechanism ``openai``.

    The current span is marked as errored. When ``manual_span_cleanup`` is
    true and a span is active, it is also exited here. This must be done by
    hand because the spans in this integration are not opened with the
    ``start_span`` context manager.
    """
    span = sentry_sdk.get_current_span()
    set_span_errored(span)

    should_close_span = manual_span_cleanup and span is not None
    if should_close_span:
        span.__exit__(None, None, None)

    client_options = sentry_sdk.get_client().options
    event, hint = event_from_exception(
        exc,
        client_options=client_options,
        mechanism={"type": "openai", "handled": False},
    )
    sentry_sdk.capture_event(event, hint=hint)


def _has_attr_and_is_int(
    token_usage: "Union[CompletionUsage, ResponseUsage]", attr_name: str
) -> bool:
    return hasattr(token_usage, attr_name) and isinstance(
        getattr(token_usage, attr_name, None), int
    )


def _calculate_completions_token_usage(
    messages: "Optional[Iterable[ChatCompletionMessageParam]]",
    response: "Any",
    span: "Span",
    streaming_message_responses: "Optional[List[str]]",
    streaming_message_total_token_usage: "Optional[CompletionUsage]",
    count_tokens: "Callable[..., Any]",
) -> None:
    """Extract and record token usage from a Chat Completions API response.

    Counts reported by the API are preferred. A usage object collected from a
    stream (``streaming_message_total_token_usage``) takes precedence over
    ``response.usage``. Input/output counts that are missing or zero are then
    estimated with ``count_tokens``:

    * input from string or dict ``messages`` (text items of ``content``);
    * output from ``streaming_message_responses`` when given, otherwise from
      the string content of each ``response.choices[i].message``.

    Any count that is still 0 is passed to ``record_token_usage`` as ``None``.
    """
    input_tokens: "Optional[int]" = 0
    input_tokens_cached: "Optional[int]" = 0
    output_tokens: "Optional[int]" = 0
    output_tokens_reasoning: "Optional[int]" = 0
    total_tokens: "Optional[int]" = 0

    usage = None
    if streaming_message_total_token_usage is not None:
        usage = streaming_message_total_token_usage
    elif hasattr(response, "usage"):
        usage = response.usage

    if usage is not None:
        if _has_attr_and_is_int(usage, "prompt_tokens"):
            input_tokens = usage.prompt_tokens
        if _has_attr_and_is_int(usage, "completion_tokens"):
            output_tokens = usage.completion_tokens
        if _has_attr_and_is_int(usage, "total_tokens"):
            total_tokens = usage.total_tokens

        # The *_details objects may be absent or None; nested getattr with a
        # default tolerates both.
        cached = getattr(
            getattr(usage, "prompt_tokens_details", None), "cached_tokens", None
        )
        if isinstance(cached, int):
            input_tokens_cached = cached

        reasoning = getattr(
            getattr(usage, "completion_tokens_details", None),
            "reasoning_tokens",
            None,
        )
        if isinstance(reasoning, int):
            output_tokens_reasoning = reasoning

    # Manually count input tokens
    if input_tokens == 0:
        for message in messages or []:
            if isinstance(message, str):
                input_tokens += count_tokens(message)
            elif isinstance(message, dict):
                message_content = message.get("content")
                if message_content is not None:
                    text_items = _get_text_items(message_content)
                    input_tokens += sum(count_tokens(text) for text in text_items)

    # Manually count output tokens
    if output_tokens == 0:
        if streaming_message_responses is not None:
            for message in streaming_message_responses:
                output_tokens += count_tokens(message)
        elif hasattr(response, "choices"):
            for choice in response.choices:
                content = getattr(getattr(choice, "message", None), "content", None)
                # Messages without text content (content=None) are skipped
                # instead of handing None to count_tokens, which expects a str.
                if isinstance(content, str):
                    output_tokens += count_tokens(content)

    # Do not set token data if it is 0
    input_tokens = input_tokens or None
    input_tokens_cached = input_tokens_cached or None
    output_tokens = output_tokens or None
    output_tokens_reasoning = output_tokens_reasoning or None
    total_tokens = total_tokens or None

    record_token_usage(
        span,
        input_tokens=input_tokens,
        input_tokens_cached=input_tokens_cached,
        output_tokens=output_tokens,
        output_tokens_reasoning=output_tokens_reasoning,
        total_tokens=total_tokens,
    )


def _calculate_responses_token_usage(
    input: "Any",
    response: "Any",
    span: "Span",
    streaming_message_responses: "Optional[List[str]]",
    count_tokens: "Callable[..., Any]",
) -> None:
    """Extract and record token usage from a Responses API response.

    Counts from ``response.usage`` are preferred. Input/output counts that are
    missing or zero are estimated with ``count_tokens``:

    * input from ``input``, which may be a bare string or a list of string or
      dict items (text items of each dict's ``content``);
    * output from ``streaming_message_responses`` when given, otherwise from
      the ``text`` of every content part of ``response.output``.

    Any count that is still 0 is passed to ``record_token_usage`` as ``None``.
    """
    input_tokens: "Optional[int]" = 0
    input_tokens_cached: "Optional[int]" = 0
    output_tokens: "Optional[int]" = 0
    output_tokens_reasoning: "Optional[int]" = 0
    total_tokens: "Optional[int]" = 0

    if hasattr(response, "usage"):
        usage = response.usage

        if _has_attr_and_is_int(usage, "input_tokens"):
            input_tokens = usage.input_tokens
        if _has_attr_and_is_int(usage, "output_tokens"):
            output_tokens = usage.output_tokens
        if _has_attr_and_is_int(usage, "total_tokens"):
            total_tokens = usage.total_tokens

        if hasattr(usage, "input_tokens_details"):
            cached = getattr(usage.input_tokens_details, "cached_tokens", None)
            if isinstance(cached, int):
                input_tokens_cached = cached

        if hasattr(usage, "output_tokens_details"):
            reasoning = getattr(usage.output_tokens_details, "reasoning_tokens", None)
            if isinstance(reasoning, int):
                output_tokens_reasoning = reasoning

    # Manually count input tokens
    if input_tokens == 0:
        # ``input`` may be a bare string. Count it as one message rather than
        # iterating over (and token-counting) each of its characters.
        input_items = [input] if isinstance(input, str) else input or []
        for message in input_items:
            if isinstance(message, str):
                input_tokens += count_tokens(message)
            elif isinstance(message, dict):
                message_content = message.get("content")
                if message_content is not None:
                    # Deliberate use of Completions function for both Completions and Responses input format.
                    text_items = _get_text_items(message_content)
                    input_tokens += sum(count_tokens(text) for text in text_items)

    # Manually count output tokens
    if output_tokens == 0:
        if streaming_message_responses is not None:
            for message in streaming_message_responses:
                output_tokens += count_tokens(message)
        elif hasattr(response, "output"):
            for output_item in response.output or []:
                # Output items whose ``content`` is missing or None are skipped
                # instead of raising TypeError on iteration.
                for content_item in getattr(output_item, "content", None) or []:
                    text = getattr(content_item, "text", None)
                    if isinstance(text, str):
                        output_tokens += count_tokens(text)

    # Do not set token data if it is 0
    input_tokens = input_tokens or None
    input_tokens_cached = input_tokens_cached or None
    output_tokens = output_tokens or None
    output_tokens_reasoning = output_tokens_reasoning or None
    total_tokens = total_tokens or None

    record_token_usage(
        span,
        input_tokens=input_tokens,
        input_tokens_cached=input_tokens_cached,
        output_tokens=output_tokens,
        output_tokens_reasoning=output_tokens_reasoning,
        total_tokens=total_tokens,
    )


def _set_responses_api_input_data(
    span: "Span",
    kwargs: "dict[str, Any]",
    integration: "OpenAIIntegration",
) -> None:
    """Record Responses API request parameters and, PII permitting, prompts.

    Request settings (tools, model, max_output_tokens, temperature, top_p) are
    always recorded. When ``send_default_pii`` is on and ``include_prompts`` is
    set, two more things are recorded:

    * the ``instructions`` kwarg and any system messages found in ``input``,
      as system instructions;
    * the rest of ``input`` (a bare string, or the non-system items), as
      request messages.

    The operation name ``"responses"`` is set on every path.
    """
    explicit_instructions: "Union[Optional[str], Omit]" = kwargs.get("instructions")
    messages: "Optional[Union[str, ResponseInputParam]]" = kwargs.get("input")

    tools = kwargs.get("tools")
    if tools is not None and _is_given(tools) and len(tools) > 0:
        set_data_normalized(
            span, SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, safe_serialize(tools)
        )

    model = kwargs.get("model")
    if model is not None:
        span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)

    # (request kwarg, span attribute) pairs, copied only when explicitly given.
    for kwarg_name, span_key in (
        ("max_output_tokens", SPANDATA.GEN_AI_REQUEST_MAX_TOKENS),
        ("temperature", SPANDATA.GEN_AI_REQUEST_TEMPERATURE),
        ("top_p", SPANDATA.GEN_AI_REQUEST_TOP_P),
    ):
        value = kwargs.get(kwarg_name)
        if value is not None and _is_given(value):
            span.set_data(span_key, value)

    if not should_send_default_pii() or not integration.include_prompts:
        set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "responses")
        return

    has_explicit_instructions = explicit_instructions is not None and _is_given(
        explicit_instructions
    )

    if messages is None:
        # Without input, only the explicit instructions can be recorded.
        if has_explicit_instructions:
            span.set_data(
                SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS,
                json.dumps([{"type": "text", "content": explicit_instructions}]),
            )
        set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "responses")
        return

    instruction_parts: "list[TextPart]" = (
        [{"type": "text", "content": explicit_instructions}]
        if has_explicit_instructions
        else []
    )
    # Deliberate use of function accepting completions API type because
    # of shared structure FOR THIS PURPOSE ONLY.
    instruction_parts += _transform_system_instructions(
        _get_system_instructions_responses(messages)
    )
    if len(instruction_parts) > 0:
        span.set_data(
            SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS,
            json.dumps(instruction_parts),
        )

    prompt_messages: "list[Any]"
    if isinstance(messages, str):
        prompt_messages = [messages]
    else:
        prompt_messages = [
            item for item in messages if not _is_system_instruction_responses(item)
        ]

    if prompt_messages:
        normalized = normalize_message_roles(prompt_messages)  # type: ignore
        scope = sentry_sdk.get_current_scope()
        truncated = truncate_and_annotate_messages(normalized, span, scope)
        if truncated is not None:
            set_data_normalized(
                span, SPANDATA.GEN_AI_REQUEST_MESSAGES, truncated, unpack=False
            )

    set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "responses")


def _set_completions_api_input_data(
    span: "Span",
    kwargs: "dict[str, Any]",
    integration: "OpenAIIntegration",
) -> None:
    """Record Chat Completions request parameters and, PII permitting, prompts.

    Request settings (tools, model, sampling/penalty parameters) are always
    recorded. Messages and system instructions are attached only when
    ``send_default_pii`` is on and ``include_prompts`` is set.

    A non-string, non-dict iterable of messages is materialized into a list
    and written back to ``kwargs["messages"]``. Presumably this keeps a
    one-shot iterator from being exhausted before the wrapped call.

    The operation name ``"chat"`` is set on every path.
    """
    messages: "Optional[Union[str, Iterable[ChatCompletionMessageParam]]]" = kwargs.get(
        "messages"
    )

    tools = kwargs.get("tools")
    if tools is not None and _is_given(tools) and len(tools) > 0:
        set_data_normalized(
            span, SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, safe_serialize(tools)
        )

    model = kwargs.get("model")
    if model is not None:
        span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)

    # (request kwarg, span attribute) pairs, copied only when explicitly given.
    for kwarg_name, span_key in (
        ("max_tokens", SPANDATA.GEN_AI_REQUEST_MAX_TOKENS),
        ("presence_penalty", SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY),
        ("frequency_penalty", SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY),
        ("temperature", SPANDATA.GEN_AI_REQUEST_TEMPERATURE),
        ("top_p", SPANDATA.GEN_AI_REQUEST_TOP_P),
    ):
        value = kwargs.get(kwarg_name)
        if value is not None and _is_given(value):
            span.set_data(span_key, value)

    record_prompts = (
        should_send_default_pii()
        and integration.include_prompts
        and messages is not None
    )
    if not record_prompts:
        set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat")
        return

    prompt_messages: "list[Any]"
    if isinstance(messages, str):
        prompt_messages = [messages]
    # dict special case following https://github.com/openai/openai-python/blob/3e0c05b84a2056870abf3bd6a5e7849020209cc3/src/openai/_utils/_transform.py#L194-L197
    elif isinstance(messages, Iterable) and not isinstance(messages, dict):
        messages = list(messages)
        kwargs["messages"] = messages

        system_instructions = _get_system_instructions_completions(messages)
        if len(system_instructions) > 0:
            span.set_data(
                SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS,
                json.dumps(_transform_system_instructions(system_instructions)),
            )

        prompt_messages = [
            item
            for item in messages
            if not _is_system_instruction_completions(item)
        ]
    else:
        set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat")
        return

    if prompt_messages:
        normalized = normalize_message_roles(prompt_messages)  # type: ignore
        scope = sentry_sdk.get_current_scope()
        truncated = truncate_and_annotate_messages(normalized, span, scope)
        if truncated is not None:
            set_data_normalized(
                span, SPANDATA.GEN_AI_REQUEST_MESSAGES, truncated, unpack=False
            )

    set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat")


def _set_embeddings_input_data(
    span: "Span",
    kwargs: "dict[str, Any]",
    integration: "OpenAIIntegration",
) -> None:
    """Record the embeddings request model and, PII permitting, its inputs.

    A non-string, non-dict iterable ``input`` is materialized into a list and
    written back to ``kwargs["input"]``. The operation name ``"embeddings"``
    is set on every path.
    """
    messages: "Union[str, SequenceNotStr[str], Iterable[int], Iterable[Iterable[int]]]" = kwargs.get(
        "input"
    )

    model = kwargs.get("model")
    if model is not None:
        span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)

    def attach_inputs(inputs: "list[Any]") -> None:
        # Normalize, truncate and attach the embedding inputs to the span.
        normalized = normalize_message_roles(inputs)  # type: ignore
        scope = sentry_sdk.get_current_scope()
        truncated = truncate_and_annotate_embedding_inputs(normalized, span, scope)
        if truncated is not None:
            set_data_normalized(
                span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, truncated, unpack=False
            )

    record_prompts = should_send_default_pii() and integration.include_prompts
    if not record_prompts or messages is None:
        set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
        return

    if isinstance(messages, str):
        set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
        attach_inputs([messages])
        return

    # dict special case following https://github.com/openai/openai-python/blob/3e0c05b84a2056870abf3bd6a5e7849020209cc3/src/openai/_utils/_transform.py#L194-L197
    if isinstance(messages, dict) or not isinstance(messages, Iterable):
        set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
        return

    materialized = list(messages)
    kwargs["input"] = materialized

    if materialized:
        attach_inputs(materialized)

    set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")


def _set_common_output_data(
    span: "Span",
    response: "Any",
    input: "Any",
    integration: "OpenAIIntegration",
    finish_span: bool = True,
) -> None:
    """Record response data and token usage on ``span``; optionally finish it.

    The response shape selects the handling:

    * ``response.choices`` (Chat Completions): message dumps become the
      response text.
    * ``response.output`` (Responses API): text parts become the response
      text and ``function_call`` items become tool calls.
    * anything else (e.g. embeddings): token usage only.

    Response content is attached only when ``send_default_pii`` is enabled and
    ``integration.include_prompts`` is set. ``input`` is the request input,
    used to estimate input tokens when the response reports none. When
    ``finish_span`` is true the span is exited at the end.
    """
    if hasattr(response, "model"):
        set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_MODEL, response.model)

    if hasattr(response, "choices"):
        # Chat Completions API
        if should_send_default_pii() and integration.include_prompts:
            response_text = [
                choice.message.model_dump()
                for choice in response.choices
                if choice.message is not None
            ]
            if len(response_text) > 0:
                set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, response_text)

        _calculate_completions_token_usage(
            messages=input,
            response=response,
            span=span,
            streaming_message_responses=None,
            streaming_message_total_token_usage=None,
            count_tokens=integration.count_tokens,
        )

    elif hasattr(response, "output"):
        # Responses API
        if should_send_default_pii() and integration.include_prompts:
            output_messages: "dict[str, list[Any]]" = {
                "response": [],
                "tool": [],
            }

            # model_dump() matches the Chat Completions branch and avoids the
            # pydantic v2 deprecation warning emitted by ``.dict()``.
            for output in response.output:
                if output.type == "function_call":
                    output_messages["tool"].append(output.model_dump())
                elif output.type == "message":
                    for output_message in output.content:
                        try:
                            output_messages["response"].append(output_message.text)
                        except AttributeError:
                            # Unknown output message type, just return the json
                            output_messages["response"].append(
                                output_message.model_dump()
                            )

            if len(output_messages["tool"]) > 0:
                set_data_normalized(
                    span,
                    SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
                    output_messages["tool"],
                    unpack=False,
                )

            if len(output_messages["response"]) > 0:
                set_data_normalized(
                    span, SPANDATA.GEN_AI_RESPONSE_TEXT, output_messages["response"]
                )

        _calculate_responses_token_usage(
            input=input,
            response=response,
            span=span,
            streaming_message_responses=None,
            count_tokens=integration.count_tokens,
        )

    else:
        # Embeddings API (fallback for responses with neither choices nor output)
        _calculate_completions_token_usage(
            messages=input,
            response=response,
            span=span,
            streaming_message_responses=None,
            streaming_message_total_token_usage=None,
            count_tokens=integration.count_tokens,
        )

    if finish_span:
        span.__exit__(None, None, None)


def _new_chat_completion_common(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
    """Generator that instruments one chat completion call with a span.

    The generator yields ``(f, args, kwargs)`` once and expects the call's
    result to be sent back in. It then records output data, or wraps a
    stream's iterator so output is recorded as the stream is consumed, and
    returns the response as its return value.

    Calls made without the integration, without ``messages``, or with a
    non-iterable ``messages`` are forwarded to ``f`` untouched.
    """
    integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
    if integration is None or "messages" not in kwargs:
        # Missing ``messages`` is an invalid call in every openai version;
        # let the client report it.
        return f(*args, **kwargs)

    try:
        iter(kwargs["messages"])
    except TypeError:
        # invalid call (in all versions), messages must be iterable
        return f(*args, **kwargs)

    model = kwargs.get("model")
    span = sentry_sdk.start_span(
        op=consts.OP.GEN_AI_CHAT,
        name=f"chat {model}",
        origin=OpenAIIntegration.origin,
    )
    # Entered manually; closed by the output/stream handling below.
    span.__enter__()
    span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")

    # Same bool handling as in https://github.com/openai/openai-python/blob/acd0c54d8a68efeedde0e5b4e6c310eef1ce7867/src/openai/resources/completions.py#L585
    streaming = kwargs.get("stream", False) or False
    span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, streaming)

    _set_completions_api_input_data(span, kwargs, integration)

    start_time = time.perf_counter()
    response = yield f, args, kwargs

    # The `_iterator` attribute check lets this fail gracefully if the
    # attribute disappears in future `openai` versions.
    if isinstance(response, Stream) and hasattr(response, "_iterator"):
        wrap_iterator = _wrap_synchronous_completions_chunk_iterator
    elif isinstance(response, AsyncStream) and hasattr(response, "_iterator"):
        wrap_iterator = _wrap_asynchronous_completions_chunk_iterator
    else:
        _set_completions_api_output_data(
            span, response, kwargs, integration, finish_span=True
        )
        return response

    request_messages = kwargs.get("messages")
    if isinstance(request_messages, str):
        request_messages = [request_messages]

    response._iterator = wrap_iterator(
        span=span,
        integration=integration,
        start_time=start_time,
        messages=request_messages,
        response=response,
        old_iterator=response._iterator,
        finish_span=True,
    )
    return response


def _set_completions_api_output_data(
    span: "Span",
    response: "Any",
    kwargs: "dict[str, Any]",
    integration: "OpenAIIntegration",
    finish_span: bool = True,
) -> None:
    """Record a non-streamed chat completion response on ``span``.

    A bare-string ``messages`` kwarg is wrapped in a list before being handed
    on as the request input, so it is treated as a single message rather than
    a sequence of characters.
    """
    raw_messages = kwargs.get("messages")
    request_messages = (
        [raw_messages] if isinstance(raw_messages, str) else raw_messages
    )

    _set_common_output_data(
        span, response, request_messages, integration, finish_span
    )


def _wrap_synchronous_completions_chunk_iterator(
    span: "Span",
    integration: "OpenAIIntegration",
    start_time: "Optional[float]",
    messages: "Optional[Iterable[ChatCompletionMessageParam]]",
    response: "Stream[ChatCompletionChunk]",
    old_iterator: "Iterator[ChatCompletionChunk]",
    finish_span: "bool",
) -> "Iterator[ChatCompletionChunk]":
    """
    Sets information received while iterating the response stream on the AI Client Span.
    Compute token count based on inputs and outputs using tiktoken if token counts are not in the model response.
    Responsible for closing the AI Client Span if instructed to by the `finish_span` argument.

    Every chunk of ``old_iterator`` is yielded unchanged. Text deltas are
    accumulated per choice, keyed by the choice's ``index`` when it has one.
    The last non-None ``usage`` seen is kept. Once the stream is exhausted,
    the time to first token, the response text (PII permitting) and token
    usage are recorded on ``span``.

    Note: the post-stream bookkeeping and span closing only run when the
    stream is consumed to the end.
    """
    ttft = None
    data_buf: "list[list[str]]" = []  # one for each choice
    streaming_message_total_token_usage = None

    for x in old_iterator:
        # Kept apart from the choice handling so a failure here neither breaks
        # the caller's iteration nor skips the rest of the chunk.
        with capture_internal_exceptions():
            span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, x.model)

        with capture_internal_exceptions():
            for position, choice in enumerate(getattr(x, "choices", None) or []):
                if not (hasattr(choice, "delta") and hasattr(choice.delta, "content")):
                    continue
                if start_time is not None and ttft is None:
                    ttft = time.perf_counter() - start_time
                # A chunk's ``choices`` need not contain every choice (with
                # n > 1 each chunk is assumed to carry a single choice), so
                # bucket text by the choice's own ``index`` and only fall back
                # to its position in the list.
                index = getattr(choice, "index", None)
                if not isinstance(index, int) or index < 0:
                    index = position
                while len(data_buf) <= index:
                    data_buf.append([])
                data_buf[index].append(choice.delta.content or "")

            # Chunks without usage carry ``usage=None``; do not let them discard
            # usage already received from an earlier chunk.
            usage = getattr(x, "usage", None)
            if usage is not None:
                streaming_message_total_token_usage = usage

        yield x

    with capture_internal_exceptions():
        if ttft is not None:
            set_data_normalized(
                span, SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN, ttft
            )
        all_responses = None
        if len(data_buf) > 0:
            all_responses = ["".join(chunk) for chunk in data_buf]
            if should_send_default_pii() and integration.include_prompts:
                set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses)

        _calculate_completions_token_usage(
            messages=messages,
            response=response,
            span=span,
            streaming_message_responses=all_responses,
            streaming_message_total_token_usage=streaming_message_total_token_usage,
            count_tokens=integration.count_tokens,
        )

    if finish_span:
        span.__exit__(None, None, None)


async def _wrap_asynchronous_completions_chunk_iterator(
    span: "Span",
    integration: "OpenAIIntegration",
    start_time: "Optional[float]",
    messages: "Optional[Iterable[ChatCompletionMessageParam]]",
    response: "AsyncStream[ChatCompletionChunk]",
    old_iterator: "AsyncIterator[ChatCompletionChunk]",
    finish_span: "bool",
) -> "AsyncIterator[ChatCompletionChunk]":
    """
    Sets information received while iterating the response stream on the AI Client Span.
    Compute token count based on inputs and outputs using tiktoken if token counts are not in the model response.
    Responsible for closing the AI Client Span if instructed to by the `finish_span` argument.

    Every chunk of ``old_iterator`` is yielded unchanged. Text deltas are
    accumulated per choice, keyed by the choice's ``index`` when it has one.
    The last non-None ``usage`` seen is kept. Once the stream is exhausted,
    the time to first token, the response text (PII permitting) and token
    usage are recorded on ``span``.

    Note: the post-stream bookkeeping and span closing only run when the
    stream is consumed to the end.
    """
    ttft = None
    data_buf: "list[list[str]]" = []  # one for each choice
    streaming_message_total_token_usage = None

    async for x in old_iterator:
        # Kept apart from the choice handling so a failure here neither breaks
        # the caller's iteration nor skips the rest of the chunk.
        with capture_internal_exceptions():
            span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, x.model)

        with capture_internal_exceptions():
            for position, choice in enumerate(getattr(x, "choices", None) or []):
                if not (hasattr(choice, "delta") and hasattr(choice.delta, "content")):
                    continue
                if start_time is not None and ttft is None:
                    ttft = time.perf_counter() - start_time
                # A chunk's ``choices`` need not contain every choice (with
                # n > 1 each chunk is assumed to carry a single choice), so
                # bucket text by the choice's own ``index`` and only fall back
                # to its position in the list.
                index = getattr(choice, "index", None)
                if not isinstance(index, int) or index < 0:
                    index = position
                while len(data_buf) <= index:
                    data_buf.append([])
                data_buf[index].append(choice.delta.content or "")

            # Chunks without usage carry ``usage=None``; do not let them discard
            # usage already received from an earlier chunk.
            usage = getattr(x, "usage", None)
            if usage is not None:
                streaming_message_total_token_usage = usage

        yield x

    with capture_internal_exceptions():
        if ttft is not None:
            set_data_normalized(
                span, SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN, ttft
            )
        all_responses = None
        if len(data_buf) > 0:
            all_responses = ["".join(chunk) for chunk in data_buf]
            if should_send_default_pii() and integration.include_prompts:
                set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses)

        _calculate_completions_token_usage(
            messages=messages,
            response=response,
            span=span,
            streaming_message_responses=all_responses,
            streaming_message_total_token_usage=streaming_message_total_token_usage,
            count_tokens=integration.count_tokens,
        )

    if finish_span:
        span.__exit__(None, None, None)


def _wrap_synchronous_responses_event_iterator(
    span: "Span",
    integration: "OpenAIIntegration",
    start_time: "Optional[float]",
    input: "Optional[Union[str, ResponseInputParam]]",
    response: "Stream[ResponseStreamEvent]",
    old_iterator: "Iterator[ResponseStreamEvent]",
    finish_span: "bool",
) -> "Iterator[ResponseStreamEvent]":
    """
    Re-yield every event of a Responses API stream while recording data on the AI Client Span.

    During iteration:
      * the first event that has a ``delta`` attribute fixes the time to first
        token (only when ``start_time`` was provided);
      * each ``delta`` is buffered so the full output text can be recorded;
      * a ``ResponseCompletedEvent`` sets the response model and hands the final
        response to ``_calculate_responses_token_usage``.

    Once the stream is exhausted, the buffered text is attached to the span when
    PII sending and ``integration.include_prompts`` are both enabled. If no
    completion event was processed, the buffered text is passed to
    ``_calculate_responses_token_usage`` together with ``integration.count_tokens``
    so usage can be derived from it. Finally the span is exited when
    ``finish_span`` is true.

    NOTE(review): the span is only exited after the stream is fully consumed; an
    early ``break`` by the consumer or an exception from ``old_iterator`` leaves
    it open — confirm this is intended.
    """
    first_token_latency: "Optional[float]" = None
    # The Responses API has no "choices": only index 0 of this list is ever used.
    text_buffers: "list[list[str]]" = []
    usage_from_completion_event = False

    for event in old_iterator:
        with capture_internal_exceptions():
            # NOTE(review): any event exposing ``delta`` is treated as output text;
            # confirm that non-text delta events (if any) should be included.
            if hasattr(event, "delta"):
                if first_token_latency is None and start_time is not None:
                    first_token_latency = time.perf_counter() - start_time
                if not text_buffers:
                    text_buffers.append([])
                text_buffers[0].append(event.delta or "")

            if isinstance(event, ResponseCompletedEvent):
                span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, event.response.model)
                _calculate_responses_token_usage(
                    input=input,
                    response=event.response,
                    span=span,
                    streaming_message_responses=None,
                    count_tokens=integration.count_tokens,
                )
                usage_from_completion_event = True

        yield event

    with capture_internal_exceptions():
        if first_token_latency is not None:
            set_data_normalized(
                span, SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN, first_token_latency
            )
        if text_buffers:
            all_responses = ["".join(parts) for parts in text_buffers]
            if should_send_default_pii() and integration.include_prompts:
                set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses)

            # Fall back to the buffered text only when the stream never
            # delivered a completion event carrying the final response.
            if not usage_from_completion_event:
                _calculate_responses_token_usage(
                    input=input,
                    response=response,
                    span=span,
                    streaming_message_responses=all_responses,
                    count_tokens=integration.count_tokens,
                )

    if finish_span:
        span.__exit__(None, None, None)


async def _wrap_asynchronous_responses_event_iterator(
    span: "Span",
    integration: "OpenAIIntegration",
    start_time: "Optional[float]",
    input: "Optional[Union[str, ResponseInputParam]]",
    response: "AsyncStream[ResponseStreamEvent]",
    old_iterator: "AsyncIterator[ResponseStreamEvent]",
    finish_span: "bool",
) -> "AsyncIterator[ResponseStreamEvent]":
    """
    Async counterpart of the synchronous Responses event-stream wrapper.

    It re-yields every event from ``old_iterator`` unchanged. While doing so, the
    first ``delta`` fixes the time to first token (when ``start_time`` is given),
    all ``delta`` values are buffered, and a ``ResponseCompletedEvent`` sets the
    response model and forwards the final response to
    ``_calculate_responses_token_usage``.

    After exhaustion, the buffered text is recorded on the span when PII sending
    and ``integration.include_prompts`` allow it. If no completion event was
    processed, the text is handed to ``_calculate_responses_token_usage`` with
    ``integration.count_tokens``. The span is then exited if ``finish_span`` is
    true.

    NOTE(review): the span is only exited after the stream is fully consumed; an
    early exit or an exception from ``old_iterator`` leaves it open — confirm.
    """
    first_token_latency: "Optional[float]" = None
    # Single-slot buffer: the Responses API has no per-choice output.
    text_buffers: "list[list[str]]" = []
    usage_from_completion_event = False

    async for event in old_iterator:
        with capture_internal_exceptions():
            # NOTE(review): every event exposing ``delta`` is buffered as text.
            if hasattr(event, "delta"):
                if first_token_latency is None and start_time is not None:
                    first_token_latency = time.perf_counter() - start_time
                if not text_buffers:
                    text_buffers.append([])
                text_buffers[0].append(event.delta or "")

            if isinstance(event, ResponseCompletedEvent):
                span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, event.response.model)
                _calculate_responses_token_usage(
                    input=input,
                    response=event.response,
                    span=span,
                    streaming_message_responses=None,
                    count_tokens=integration.count_tokens,
                )
                usage_from_completion_event = True

        yield event

    with capture_internal_exceptions():
        if first_token_latency is not None:
            set_data_normalized(
                span, SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN, first_token_latency
            )
        if text_buffers:
            all_responses = ["".join(parts) for parts in text_buffers]
            if should_send_default_pii() and integration.include_prompts:
                set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, all_responses)
            # Usage from a completion event wins; otherwise use the buffered text.
            if not usage_from_completion_event:
                _calculate_responses_token_usage(
                    input=input,
                    response=response,
                    span=span,
                    streaming_message_responses=all_responses,
                    count_tokens=integration.count_tokens,
                )
    if finish_span:
        span.__exit__(None, None, None)


def _set_responses_api_output_data(
    span: "Span",
    response: "Any",
    kwargs: "dict[str, Any]",
    integration: "OpenAIIntegration",
    finish_span: bool = True,
) -> None:
    """
    Record the output of a Responses API call on ``span``.

    A bare string ``input`` keyword argument is wrapped in a one-element list.
    The (possibly wrapped) input is then forwarded to ``_set_common_output_data``
    together with the response and ``finish_span``.
    """
    request_input = kwargs.get("input")
    # ``isinstance`` already rejects ``None``, so no separate check is needed.
    if isinstance(request_input, str):
        request_input = [request_input]

    _set_common_output_data(span, response, request_input, integration, finish_span)


def _set_embeddings_output_data(
    span: "Span",
    response: "Any",
    kwargs: "dict[str, Any]",
    integration: "OpenAIIntegration",
    finish_span: bool = True,
) -> None:
    """
    Record the output of an embeddings call on ``span``.

    A single string ``input`` is promoted to a one-element list. The (possibly
    promoted) input, the response and ``finish_span`` are then passed on to
    ``_set_common_output_data``.
    """
    request_input = kwargs.get("input")
    # ``isinstance`` is False for ``None``, so missing input passes through as-is.
    if isinstance(request_input, str):
        request_input = [request_input]

    _set_common_output_data(span, response, request_input, integration, finish_span)


def _wrap_chat_completion_create(f: "Callable[..., Any]") -> "Callable[..., Any]":
    def _execute_sync(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
        gen = _new_chat_completion_common(f, *args, **kwargs)

        try:
            f, args, kwargs = next(gen)
        except StopIteration as e:
            return e.value

        try:
            try:
                result = f(*args, **kwargs)
            except Exception as e:
                exc_info = sys.exc_info()
                with capture_internal_exceptions():
                    _capture_exception(e)
                reraise(*exc_info)

            return gen.send(result)
        except StopIteration as e:
            return e.value

    @wraps(f)
    def _sentry_patched_create_sync(*args: "Any", **kwargs: "Any") -> "Any":
        integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
        if integration is None or "messages" not in kwargs:
            # no "messages" means invalid call (in all versions of openai), let it return error
            return f(*args, **kwargs)

        return _execute_sync(f, *args, **kwargs)

    return _sentry_patched_create_sync


def _wrap_async_chat_completion_create(f: "Callable[..., Any]") -> "Callable[..., Any]":
    async def _execute_async(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
        gen = _new_chat_completion_common(f, *args, **kwargs)

        try:
            f, args, kwargs = next(gen)
        except StopIteration as e:
            return await e.value

        try:
            try:
                result = await f(*args, **kwargs)
            except Exception as e:
                exc_info = sys.exc_info()
                with capture_internal_exceptions():
                    _capture_exception(e)
                reraise(*exc_info)

            return gen.send(result)
        except StopIteration as e:
            return e.value

    @wraps(f)
    async def _sentry_patched_create_async(*args: "Any", **kwargs: "Any") -> "Any":
        integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
        if integration is None or "messages" not in kwargs:
            # no "messages" means invalid call (in all versions of openai), let it return error
            return await f(*args, **kwargs)

        return await _execute_async(f, *args, **kwargs)

    return _sentry_patched_create_async


def _new_embeddings_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
    """
    Instrumentation generator shared by the sync and async embeddings wrappers.

    * ``next()`` does three things: it opens an embeddings span, records the
      request data on it, and yields ``(f, args, kwargs)`` so the caller can
      perform (and, for the async client, await) the real call.
    * ``send(response)`` records the response data, leaves the ``with`` block
      that owns the span, and stops with ``response`` as the value.

    When the OpenAI integration is not enabled, no span is created. In that case
    the first ``next()`` stops immediately with ``f(*args, **kwargs)`` as the
    value.
    """
    integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
    if integration is None:
        return f(*args, **kwargs)

    model = kwargs.get("model")
    with sentry_sdk.start_span(
        op=consts.OP.GEN_AI_EMBEDDINGS,
        name=f"embeddings {model}",
        origin=OpenAIIntegration.origin,
    ) as embeddings_span:
        embeddings_span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")
        _set_embeddings_input_data(embeddings_span, kwargs, integration)

        response = yield f, args, kwargs

        # The ``with`` statement closes the span, so the helper must not.
        _set_embeddings_output_data(
            embeddings_span, response, kwargs, integration, finish_span=False
        )

    return response


def _wrap_embeddings_create(f: "Any") -> "Any":
    def _execute_sync(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
        gen = _new_embeddings_create_common(f, *args, **kwargs)

        try:
            f, args, kwargs = next(gen)
        except StopIteration as e:
            return e.value

        try:
            try:
                result = f(*args, **kwargs)
            except Exception as e:
                exc_info = sys.exc_info()
                with capture_internal_exceptions():
                    _capture_exception(e, manual_span_cleanup=False)
                reraise(*exc_info)

            return gen.send(result)
        except StopIteration as e:
            return e.value

    @wraps(f)
    def _sentry_patched_create_sync(*args: "Any", **kwargs: "Any") -> "Any":
        integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
        if integration is None:
            return f(*args, **kwargs)

        return _execute_sync(f, *args, **kwargs)

    return _sentry_patched_create_sync


def _wrap_async_embeddings_create(f: "Any") -> "Any":
    async def _execute_async(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
        gen = _new_embeddings_create_common(f, *args, **kwargs)

        try:
            f, args, kwargs = next(gen)
        except StopIteration as e:
            return await e.value

        try:
            try:
                result = await f(*args, **kwargs)
            except Exception as e:
                exc_info = sys.exc_info()
                with capture_internal_exceptions():
                    _capture_exception(e, manual_span_cleanup=False)
                reraise(*exc_info)

            return gen.send(result)
        except StopIteration as e:
            return e.value

    @wraps(f)
    async def _sentry_patched_create_async(*args: "Any", **kwargs: "Any") -> "Any":
        integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
        if integration is None:
            return await f(*args, **kwargs)

        return await _execute_async(f, *args, **kwargs)

    return _sentry_patched_create_async


def _new_responses_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
    """
    Instrumentation generator shared by the sync and async Responses wrappers.

    * ``next()`` opens (manually enters) a responses span and records the system,
      the streaming flag and the request data. It then yields
      ``(f, args, kwargs)`` for the caller to perform.
    * ``send(response)`` handles the result, which is either:
        - a ``Stream`` / ``AsyncStream`` with an ``_iterator`` attribute: the
          iterator is replaced by an event wrapper that records data and exits
          the span once the stream is consumed; or
        - any other response: output data is recorded and the span is exited
          immediately.
      Either way the generator stops with ``response`` as the value.

    When the OpenAI integration is not enabled, the first ``next()`` stops
    immediately with ``f(*args, **kwargs)`` as the value.
    """
    integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
    if integration is None:
        return f(*args, **kwargs)

    model = kwargs.get("model")

    # Entered by hand instead of via ``with``: for streamed responses the span
    # has to outlive this generator and is exited by the stream wrapper.
    span = sentry_sdk.start_span(
        op=consts.OP.GEN_AI_RESPONSES,
        name=f"responses {model}",
        origin=OpenAIIntegration.origin,
    )
    span.__enter__()
    span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")

    # Same bool handling as in https://github.com/openai/openai-python/blob/acd0c54d8a68efeedde0e5b4e6c310eef1ce7867/src/openai/resources/responses/responses.py#L940
    streaming = kwargs.get("stream", False) or False
    span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, streaming)

    _set_responses_api_input_data(span, kwargs, integration)

    start_time = time.perf_counter()
    response = yield f, args, kwargs

    # The ``_iterator`` attribute checks let this degrade gracefully should a
    # future ``openai`` release drop that private attribute.
    if isinstance(response, Stream) and hasattr(response, "_iterator"):
        wrap_events = _wrap_synchronous_responses_event_iterator
    elif isinstance(response, AsyncStream) and hasattr(response, "_iterator"):
        wrap_events = _wrap_asynchronous_responses_event_iterator
    else:
        _set_responses_api_output_data(
            span, response, kwargs, integration, finish_span=True
        )
        return response

    request_input = kwargs.get("input")
    if isinstance(request_input, str):
        request_input = [request_input]

    response._iterator = wrap_events(
        span=span,
        integration=integration,
        start_time=start_time,
        input=request_input,
        response=response,
        old_iterator=response._iterator,
        finish_span=True,
    )
    return response


def _wrap_responses_create(f: "Any") -> "Any":
    def _execute_sync(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
        gen = _new_responses_create_common(f, *args, **kwargs)

        try:
            f, args, kwargs = next(gen)
        except StopIteration as e:
            return e.value

        try:
            try:
                result = f(*args, **kwargs)
            except Exception as e:
                exc_info = sys.exc_info()
                with capture_internal_exceptions():
                    _capture_exception(e)
                reraise(*exc_info)

            return gen.send(result)
        except StopIteration as e:
            return e.value

    @wraps(f)
    def _sentry_patched_create_sync(*args: "Any", **kwargs: "Any") -> "Any":
        integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
        if integration is None:
            return f(*args, **kwargs)

        return _execute_sync(f, *args, **kwargs)

    return _sentry_patched_create_sync


def _wrap_async_responses_create(f: "Any") -> "Any":
    async def _execute_async(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
        gen = _new_responses_create_common(f, *args, **kwargs)

        try:
            f, args, kwargs = next(gen)
        except StopIteration as e:
            return await e.value

        try:
            try:
                result = await f(*args, **kwargs)
            except Exception as e:
                exc_info = sys.exc_info()
                with capture_internal_exceptions():
                    _capture_exception(e)
                reraise(*exc_info)

            return gen.send(result)
        except StopIteration as e:
            return e.value

    @wraps(f)
    async def _sentry_patched_responses_async(*args: "Any", **kwargs: "Any") -> "Any":
        integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
        if integration is None:
            return await f(*args, **kwargs)

        return await _execute_async(f, *args, **kwargs)

    return _sentry_patched_responses_async


def _is_given(obj: "Any") -> bool:
    """
    Return True unless ``obj`` is one of openai's "argument not provided" sentinels.

    The sentinel classes ``NotGiven`` and ``Omit`` are set to ``None`` when they
    cannot be imported from the installed ``openai`` version; unavailable ones
    are simply not checked.
    """
    sentinel_classes = tuple(cls for cls in (NotGiven, Omit) if cls is not None)
    # ``isinstance(obj, ())`` is False, so an empty tuple means "always given".
    return not isinstance(obj, sentinel_classes)
