# mypy: allow-untyped-defs
"""
Contains various utils for AOTAutograd, including those for handling collections.
"""

import copy
import dataclasses
import logging
import operator
import warnings
from collections.abc import Callable
from contextlib import nullcontext
from functools import wraps
from typing import Any, Optional, TypeVar, Union
from typing_extensions import ParamSpec

import torch
import torch.utils._pytree as pytree
from torch._library.fake_class_registry import FakeScriptObject
from torch._logging import getArtifactLogger
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import py_sym_types

from .descriptors import AOTOutput


# Leaf types allowed in the flattened output of a function wrapped by
# `create_tree_flattened_fn`; any other leaf type makes its `flat_fn` raise a
# RuntimeError asking the user to register a pytree node for it.
KNOWN_TYPES = [
    torch.Tensor,
    BackwardState,
    int,
    str,
    float,
    bool,
    type(None),
    *py_sym_types,
    FakeScriptObject,
    torch.ScriptObject,
]

# Reference to the builtin `zip`, captured at import time; `strict_zip` below
# delegates to it.
original_zip = zip

# Artifact logger for graph dumps emitted by `unlift_tokens` before tokens are
# removed from the forward/backward graphs.
aot_graphs_effects_log = getArtifactLogger(__name__, "aot_graphs_effects")
# Artifact logger for tracing how forward-node metadata is copied onto backward
# nodes (`copy_fwd_metadata_to_bw_nodes` and its helpers).
annotation_log = getArtifactLogger(__name__, "annotation")


def strict_zip(*iterables, strict=True, **kwargs):
    """
    ``zip`` that, by default, requires all iterables to have the same length.

    With ``strict=True`` (the default) every iterable must support ``len()``,
    and a ``ValueError`` is raised up front if the lengths differ. With
    ``strict=False`` this behaves like plain ``zip``. Extra keyword arguments
    are forwarded to ``zip``. Calling it with no iterables returns an empty
    iterator, just like ``zip()``.
    """
    # With no iterables there is nothing to compare (and no ``iterables[0]`` to
    # measure), so defer to the builtin, which yields nothing.
    if not strict or not iterables:
        return original_zip(*iterables, **kwargs)

    length = len(iterables[0])
    if any(len(iterable) != length for iterable in iterables[1:]):
        raise ValueError(
            "The iterables have different lengths and strict mode is enabled."
        )

    return original_zip(*iterables, **kwargs)


def _get_symint_hints(exprs):
    """
    Get the hints of a list/tuple of int/SymInt.
    """
    if isinstance(exprs, (list, tuple)):
        return type(exprs)(_get_symint_hints(e) for e in exprs)
    elif isinstance(exprs, torch.SymInt):
        return exprs.node.shape_env.size_hint(exprs.node.expr)
    else:
        return exprs


def partial_flatten_asdict(obj: Any) -> Any:
    """
    Shallowly convert dataclass instances to dicts, recursing through containers.

    - A dataclass *instance* becomes ``{field_name: value}``. Field values are
      NOT converted further (hence "partial").
    - Lists and tuples are rebuilt with the same type, each element converted
      recursively. Namedtuples are rebuilt through ``_make`` because their
      constructors take one positional argument per field.
    - Dicts become plain dicts with each value converted recursively.
    - Anything else, including dataclass *types*, is returned unchanged.
    """
    # ``is_dataclass`` is also true for dataclass classes; only instances
    # have field values to read.
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        return {
            field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)
        }
    elif isinstance(obj, (list, tuple)):
        items = [partial_flatten_asdict(item) for item in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields") and hasattr(obj, "_make"):
            # namedtuple: ``cls(items)`` would bind the whole list to the first field.
            return obj._make(items)
        return obj.__class__(items)
    elif isinstance(obj, dict):
        return {k: partial_flatten_asdict(v) for k, v in obj.items()}
    else:
        return obj


def normalize_as_list(x):
    """
    Coerce ``x`` into a list.

    A list is returned as the very same object, a tuple is copied into a new
    list, and any other value is wrapped as a single-element list.
    """
    if isinstance(x, list):
        return x
    if isinstance(x, tuple):
        return list(x)
    return [x]


def _get_autocast_states():
    return [
        torch.is_autocast_enabled("cuda"),
        torch.is_autocast_enabled("cpu"),
        torch.get_autocast_dtype("cuda"),
        torch.get_autocast_dtype("cpu"),
        torch.is_autocast_cache_enabled(),
    ]


def make_boxed_func(f):
    """
    Adapt ``f`` (normally called as ``f(*args)``) to the boxed calling
    convention, where the callee receives its positional arguments as a single
    sequence.

    The returned adapter carries ``_boxed_call = True``.
    ``call_func_at_runtime_with_args`` checks this flag to decide how to invoke it.
    """

    def g(args):
        # Unpack the boxed argument sequence for the unboxed callee.
        return f(*args)

    g = simple_wraps(f)(g)
    g._boxed_call = True  # type: ignore[attr-defined]
    return g


def make_boxed_compiler(compiler):
    """
    Wrap an AOTAutograd ``compiler`` so that the callable it produces is boxed.

    The wrapper calls ``compiler(fx_g, inps)`` and passes the resulting
    callable through ``make_boxed_func``.
    """

    @wraps(compiler)
    def f(fx_g, inps):
        compiled_fn = compiler(fx_g, inps)
        return make_boxed_func(compiled_fn)

    return f


def call_func_at_runtime_with_args(
    f, args: Union[tuple[Any], list[Any]], steal_args=False, disable_amp=False
):
    """
    Invoke a compiled function and always return its outputs as a list.

    Args:
        f: the compiled callable. If it has a truthy ``_boxed_call`` attribute
            it is called as ``f(args)``; otherwise it is called as ``f(*args)``
            and a deprecation warning is emitted.
        args: positional arguments. Unless ``steal_args`` is set, they are
            copied into a fresh list first. With ``steal_args`` the caller
            must pass a list, which is handed to ``f`` directly.
        steal_args: skip the defensive copy of ``args``.
        disable_amp: run the call under ``torch._C._DisableAutocast``.

    Returns:
        The outputs of ``f`` normalized via ``normalize_as_list``.
    """
    if not steal_args:
        args = list(args)
    assert isinstance(args, list)

    ctx_factory = nullcontext if not disable_amp else torch._C._DisableAutocast
    with ctx_factory():
        is_boxed = getattr(f, "_boxed_call", False)
        if not is_boxed:
            # TODO: Please remove soon
            # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
            warnings.warn(
                "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. "
                "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
                "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.",
                stacklevel=2,
            )
            raw_out = f(*args)
        else:
            raw_out = f(args)
        out = normalize_as_list(raw_out)
    return out


# Inspired by autodidax (thanks!)
class PytreeThunk:
    spec: Optional[pytree.TreeSpec] = None
    # These are some kinda dumb microoptimizations that save about 3-4 us of overhead.
    is_simple: Optional[bool] = (
        None  # if the output spec is a tuple/list, we won't bother unflattening it.
    )
    is_really_simple: Optional[bool] = None  # if the output spec is a LeafSpec

    def set(self, spec: pytree.TreeSpec) -> None:
        assert self.spec is None or self.spec == spec
        assert spec is not None
        self.spec: pytree.TreeSpec = spec
        if self.spec.type in {tuple, list} and all(
            child.is_leaf() for child in spec.children()
        ):
            self.is_simple = True
        if self.spec.is_leaf():
            self.is_really_simple = True

    def unflatten(self, x: list[Any]) -> Any:
        if self.is_really_simple:
            return x[0]
        if self.is_simple:
            return x
        assert self.spec is not None
        return pytree.tree_unflatten(x, self.spec)


# Creates a function that returns flattened inputs and outputs
# Also returns the output tree spec, which is needed to recover the "unflattened"
# output tree structure later.
def create_tree_flattened_fn(fn, args, kwargs=None) -> tuple[Callable, PytreeThunk]:
    """
    Wrap ``fn`` so that it consumes and produces flat lists of pytree leaves.

    Args:
        fn: the function to wrap, called as ``fn(*args, **kwargs)``.
        args, kwargs: example inputs. Only their pytree structure is kept; the
            returned ``flat_fn`` takes the leaves of ``(args, kwargs)`` as its
            positional arguments.

    Returns:
        ``(flat_fn, out_spec)``. ``flat_fn`` returns the flattened outputs of
        ``fn``. ``out_spec`` is a ``PytreeThunk`` that is populated with the
        output spec each time ``flat_fn`` runs, so outputs can be unflattened
        later.

    Raises (from ``flat_fn``):
        RuntimeError: if an output leaf is not an instance of ``KNOWN_TYPES``.
    """
    if kwargs is None:
        kwargs = {}
    # Save the args_spec for flat_tensor_args to unflatten while tracing
    _, tensor_args_spec = pytree.tree_flatten((args, kwargs))
    out_spec = PytreeThunk()

    def flat_fn(*flat_args):
        # The input are flattened tensor args. Prepare the args in the
        # order that original function expects. Add static args as well.
        # They will appear as tensor constants in the traced graph.
        args, kwargs = pytree.tree_unflatten(flat_args, tensor_args_spec)
        tree_out = fn(*args, **kwargs)
        flat_out, spec = pytree.tree_flatten(tree_out)
        # Read KNOWN_TYPES at call time so later additions to it are honored.
        known_types = tuple(KNOWN_TYPES)
        for leaf in flat_out:
            if not isinstance(leaf, known_types):
                raise RuntimeError(
                    f"Found {type(leaf)} in output, which is not a known type. "
                    "If this type holds tensors, you need to register a pytree for it. "
                    "See https://github.com/pytorch/functorch/issues/475 for a brief "
                    "explanation why. If you don't need to register a pytree, please "
                    "leave a comment explaining your use case and we'll make this more "
                    "ergonomic to deal with"
                )
        # out_spec is only mutated (never rebound), so no `nonlocal` is needed.
        out_spec.set(spec)
        return flat_out

    # Can't use functools.wraps here because the wrapper has different
    # calling convention
    if hasattr(fn, "_orig_mod"):
        flat_fn._orig_mod = fn._orig_mod  # type: ignore[attr-defined]

    return flat_fn, out_spec


# This function takes in a tensor t, and returns one of t, t.view(), or t.clone().
# When tracing the joint forward + backward, for any inputs in the graph that are mutated,
# we need to clone them first (and similarly for metadata-only mutations, we need to view them first).
# The idea is that when we trace the backward, we need to pass in the *original* primals
# to autograd.grad(), before they were mutated.
# Note: when we have synthetic base inputs, we need to clone them *before* creating views off of them.
# This means that "idx" here represents the index of the (potentially) synthetic base.
# What we need to do is:
# (1) map the current (post-synthetic-base calling convention) input argument index
#     to its index in the pre-synthetic-base calling convention.
# (2) There could be multiple, if this index corresponds to a synthetic base
#     that has multiple input aliases.
# (3) If any of those corresponding inputs get metadata mutations, then we clone the base.
def maybe_to_fresh_input(idx, t, meta):
    """
    Return ``t``, ``t.clone()`` or ``t.view(t.shape)`` so that the primal later
    passed to ``autograd.grad()`` still reflects the input before mutation.

    Non-tensors, and inputs whose ``idx`` is not in
    ``meta.mutated_inp_runtime_indices``, are returned as-is. A data-mutated
    input that requires grad is cloned. Otherwise a metadata-mutated input is
    returned as a fresh view.
    """
    if not isinstance(t, torch.Tensor) or idx not in meta.mutated_inp_runtime_indices:
        return t
    info = meta.input_info[idx]
    # Only mutated inputs that participate in autograd need a real copy.
    if info.requires_grad and info.mutates_data:
        return t.clone()
    # NOTE(review): `info` was already dereferenced above, so this truthiness
    # test looks like it was meant to be `info.requires_grad`; it is kept
    # as-is to preserve behavior -- confirm the intent.
    if info and info.mutates_metadata:
        return t.view(t.shape)
    return t


def is_with_effects(node):
    if (
        node.op == "call_function"
        and node.target is torch.ops.higher_order.with_effects
    ):
        return True
    elif (
        node.op == "call_function"
        and node.target is torch.ops.higher_order.invoke_subgraph
    ):
        # Check if subgraph has effects by looking in the cache
        from torch._guards import InvokeSubgraphCache, TracingContext

        tracing_ctx = TracingContext.try_get()
        if tracing_ctx:
            invoke_subgraph_cache = tracing_ctx.hop_dispatch_set_cache.get_cache(
                torch.ops.higher_order.invoke_subgraph
            )
            if invoke_subgraph_cache:
                assert isinstance(invoke_subgraph_cache, InvokeSubgraphCache)
                effects = invoke_subgraph_cache.get_effects(node.args[1])
                return effects is not None
    return False


def unlift_tokens(fw_module, fw_metadata, aot_config, bw_module=None):
    # Remove the tokens from the inputs/outputs of the graph since inductor does
    # not want these extra inputs/outputs, and replace them with
    # _make_token() to create a token, and _sink_tokens() to collect the
    # tokens.  See Note [Side-Effectful Tokens in AOTAutograd]
    # Logic:
    # 1. In the case of with_effects:
    #   Before:
    #   ```
    #   def forward(self, token, arg1_1):
    #       with_effects = torch.ops.higher_order.with_effects(token, ...)
    #       getitem = with_effects[0]
    #       getitem_1 = with_effects[0]
    #       return (getitem, getitem_1)
    #   ```
    #
    #   After:
    #   ```
    #   def forward(self, arg1_1):
    #       _make_token_default = torch.ops.prims._make_token.default()
    #       with_effects = torch.ops.higher_order.with_effects(_make_token_default, ...)
    #       getitem = with_effects[0]
    #       getitem_1 = with_effects[0]
    #       _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem]);
    #       return (getitem_1,)
    #   ```
    #
    # 2. In the case of an invoke_subgraph node, we will use the
    # InvokeSubgraphCache to determine if the subgraph has effects. Then we will
    # turn it into a `with_effects` node. This is so that at the toplevel graph,
    # the nodes will have the correct with_effects threading. We will apply this
    # pass recursively to submodules so the tokens will be removed from the
    # subgraph's inputs.
    #
    #   Before:
    #   ```
    #   def forward(self, token, arg1_1):
    #       repeated_subgraph0 = self.repeated_subgraph0
    #       invoke_subgraph = torch.ops.higher_order.invoke_subgraph(
    #           repeated_subgraph0, 'subgraph_0', token, x, arg1_1)
    #       getitem = invoke_subgraph[0]
    #       getitem_1 = invoke_subgraph[1]
    #       return (getitem, getitem1)
    #   ```
    #
    #   After:
    #   ```
    #   def forward(self, arg1_1):
    #       _make_token_default = torch.ops.prims._make_token.default()
    #       repeated_subgraph0 = self.repeated_subgraph0
    #       with_effects_1 = torch.ops.higher_order.with_effects(
    #           _make_token_default, torch.ops.higher_order.invoke_subgraph,
    #           repeated_subgraph0, 'subgraph_0', arg1_1)
    #       getitem = with_effects_1[0]
    #       getitem_1 = with_effects_1[1];  with_effects_1 = None
    #       _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem])
    #       return (getitem_1,)
    #   ```
    #
    # 3. The toplevel module should have the following invariants:
    #   forward:
    #     expected_num_erased_inputs == len(fw_metadata.tokens)
    #     expected_num_erased_outputs == len(fw_metadata.tokens)
    #   backward:
    #     expected_num_erased_inputs == fw_metadata.num_backward_tokens
    #     expected_num_erased_outputs == fw_metadata.num_backward_tokens
    num_forward_tokens = len(fw_metadata.tokens)
    num_backward_tokens = fw_metadata.num_backward_tokens

    def replace_input_token_with_make_token(module, node):
        with module.graph.inserting_before(node):
            new_token_node = module.graph.call_function(
                torch.ops.prims._make_token.default, ()
            )
            new_token_node.meta["val"] = torch.tensor([])
            new_token_node.meta["tensor_meta"] = torch.tensor([])
            node.replace_all_uses_with(new_token_node)
            module.graph.erase_node(node)

    def get_output_tokens(node: torch.fx.Node) -> set[torch.fx.Node]:
        output_tokens = set()
        for user in list(node.users.keys()):
            # Check if this is a getitem accessing index 0 (the token)
            if (
                user.op == "call_function"
                and user.target is operator.getitem
                and len(user.args) > 1
                and user.args[1] == 0
            ):
                # Check if this getitem is used in an output
                for user_user in list(user.users.keys()):
                    if user_user.op == "output":
                        output_tokens.add(user)
        return output_tokens

    def _unlift_tokens_from_module_helper(
        module: torch.fx.GraphModule,
        subgraph_str: str,
        expected_num_erased: Optional[int],
    ):
        input_token_nodes = set()
        output_token_nodes = set()

        for node in module.graph.nodes:
            if (
                node.op == "call_function"
                and node.target is torch.ops.higher_order.with_effects
            ):
                if node.args[0].op == "placeholder":
                    input_token_nodes.add(node.args[0])
                    replace_input_token_with_make_token(module, node.args[0])

                tokens_from_with_effects = get_output_tokens(node)
                output_token_nodes = output_token_nodes | tokens_from_with_effects

            elif (
                node.op == "call_function"
                and node.target is torch.ops.higher_order.invoke_subgraph
            ):
                subgraph_node, identifier, *operands = node.args

                # Check if subgraph has effects by looking in the cache
                from torch._guards import InvokeSubgraphCache, TracingContext

                effects = None
                tracing_ctx = TracingContext.try_get()
                if tracing_ctx:
                    invoke_subgraph_cache = (
                        tracing_ctx.hop_dispatch_set_cache.get_cache(
                            torch.ops.higher_order.invoke_subgraph
                        )
                    )
                    if invoke_subgraph_cache:
                        assert isinstance(invoke_subgraph_cache, InvokeSubgraphCache)
                        effects = invoke_subgraph_cache.get_effects(identifier)

                if effects is not None:
                    # Wrap invoke_subgraph with with_effects
                    # Before: invoke_subgraph(subgraph, id, token, *args) -> (token_out, result)
                    # After: with_effects(token, invoke_subgraph, subgraph, id, *args) -> (token_out, result)
                    #
                    # Note: The subgraph itself will be unlifted separately when we iterate
                    # through named_modules() below.

                    num_tokens = len(effects)
                    assert num_tokens == 1, "Multiple token subgraph NYI"
                    token_args = operands[:num_tokens]
                    non_token_args = operands[num_tokens:]

                    # Create with_effects wrapper around invoke_subgraph
                    # with_effects(token, op, *args) where op is invoke_subgraph
                    # Pass the subgraph and non-token args to invoke_subgraph
                    with module.graph.inserting_before(node):
                        new_node = module.graph.call_function(
                            torch.ops.higher_order.with_effects,
                            (
                                token_args[0],  # pyrefly: ignore[bad-argument-type]
                                torch.ops.higher_order.invoke_subgraph,
                                subgraph_node,
                                identifier,
                                *tuple(non_token_args),
                            ),
                        )
                        node.replace_all_uses_with(new_node)
                        new_node.meta = node.meta
                        module.graph.erase_node(node)

                    for token in token_args:
                        if token.op == "placeholder":
                            input_token_nodes.add(token)
                            replace_input_token_with_make_token(module, token)

                    # Get output tokens from the new with_effects node
                    tokens_from_invoke_subgraph = get_output_tokens(new_node)
                    output_token_nodes = (
                        output_token_nodes | tokens_from_invoke_subgraph
                    )

        output_node = next(reversed(module.graph.find_nodes(op="output")))
        assert output_node is not None
        with module.graph.inserting_before(output_node):
            module.graph.call_function(
                torch.ops.prims._sink_tokens.default,
                (list(output_token_nodes),),
            )
        new_out_args = tuple(
            [out for out in output_node.args[0] if out not in output_token_nodes]
        )
        output_node.args = (new_out_args,)

        if expected_num_erased:
            assert len(input_token_nodes) == expected_num_erased, (
                f"{subgraph_str} num_erased_inputs:{len(input_token_nodes)} "
                f"{input_token_nodes} != expected {expected_num_erased} \n"
                f"{fw_module.print_readable(print_output=False)}"
            )
            assert len(output_token_nodes) == expected_num_erased, (
                f"{subgraph_str} num_erased_outs:{len(output_token_nodes)} "
                f"{output_token_nodes} != expected {expected_num_erased} \n"
                f"{fw_module.print_readable(print_output=False)}"
            )

        module.recompile()

    def unlift_tokens_from_module(module, subgraph_str, expected_num_erased):
        for name, m in module.named_modules():
            if isinstance(m, torch.fx.GraphModule):
                if name == "":
                    _unlift_tokens_from_module_helper(
                        m, subgraph_str, expected_num_erased
                    )
                else:
                    # Subgraph -- we may or may not have effects applied
                    _unlift_tokens_from_module_helper(m, f"{subgraph_str}_{name}", None)

    if num_forward_tokens > 0:
        if aot_config.enable_log:
            from torch._dynamo.utils import lazy_format_graph_code

            aot_graphs_effects_log.debug(
                "%s",
                lazy_format_graph_code(
                    "Forward graph before unlifting tokens",
                    fw_module,
                    aot_config.aot_id,
                    include_stride=True,
                    include_device=True,
                    colored=True,
                ),
            )
        unlift_tokens_from_module(
            fw_module,
            "forward",
            num_forward_tokens,
        )

    if bw_module is not None and num_backward_tokens > 0:
        if aot_config.enable_log:
            from torch._dynamo.utils import lazy_format_graph_code

            aot_graphs_effects_log.debug(
                "%s",
                lazy_format_graph_code(
                    "Backward graph before unlifting tokens",
                    bw_module,
                    aot_config.aot_id,
                    include_stride=True,
                    include_device=True,
                    colored=True,
                ),
            )
        unlift_tokens_from_module(bw_module, "backward", num_backward_tokens)

    # This is sad, but we need to update the metadata to get rid of
    # the tokens.
    fw_metadata.tokens = {}
    fw_metadata.num_backward_tokens = 0


def root_module_when_exporting_non_strict(flat_fn):
    """
    Return ``flat_fn._orig_mod._export_root`` if that attribute chain exists,
    else None.

    Non-strict export wraps the root module in this specific pattern (see
    ``_aot_export_non_strict`` in torch.export._trace.py); this helper detects it.
    """
    orig_mod = getattr(flat_fn, "_orig_mod", None)
    return getattr(orig_mod, "_export_root", None)


def _is_forward_node_with_seq_nr(node: torch.fx.Node) -> bool:
    # For now, assume that if nn_module_stack_metadata is populated, this
    # node is from the forward. Ignore nodes without `seq_nr`.
    # TODO(future): there is likely a less brittle way to do this by walking
    # the descendants of graph inputs corresponding to fwd inputs, didn't
    # seem obvious at first glance on how to partition graph inputs into
    # fwd vs bwd without relying on string names.
    return node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta


def _is_backward_node_with_seq_nr(node: torch.fx.Node) -> bool:
    # For now, assume that if nn_module_stack_metadata is not populated,
    # this node is from the backward. Ignore nodes without `seq_nr`.
    # TODO(future): there is likely a less brittle way to do this, same
    # as with the forward.
    return node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta


def _collect_fwd_nodes_from_subgraph(
    fx_g: torch.fx.GraphModule, fwd_seq_nr_to_node: dict[str, torch.fx.Node]
) -> None:
    """
    Record, for every forward ``seq_nr`` in ``fx_g``, the first node carrying it.

    ``fwd_seq_nr_to_node`` is updated in place. Entries that already exist
    (e.g. from a subgraph visited earlier) are never overwritten.
    """
    forward_nodes = (n for n in fx_g.graph.nodes if _is_forward_node_with_seq_nr(n))
    for fwd_node in forward_nodes:
        # A later op with an already-seen `seq_nr` did not create its own
        # autograd node, so it has no corresponding backward node: first wins.
        fwd_seq_nr_to_node.setdefault(fwd_node.meta["seq_nr"], fwd_node)


def _copy_metadata_to_bw_nodes_in_subgraph(
    fx_g: torch.fx.GraphModule, fwd_seq_nr_to_node: dict[str, torch.fx.Node]
) -> None:
    """
    Annotate each backward node of ``fx_g`` with metadata from its forward node.

    The matching forward node is looked up by ``seq_nr``. ``nn_module_stack``
    and ``source_fn_stack`` are copied as ``fwd_nn_module_stack`` /
    ``fwd_source_fn_stack``, and ``custom`` is deep-copied when present.
    Gradient-accumulation nodes are skipped.
    """
    for bw_node in fx_g.graph.nodes:
        annotation_log.debug("node: %s", bw_node.name)
        annotation_log.debug("seq_nr: %s", bw_node.meta.get("seq_nr"))

        if not _is_backward_node_with_seq_nr(bw_node):
            continue

        # We exclude gradient accumulation nodes from copying tags
        if bw_node.meta.get("is_gradient_acc", False):
            annotation_log.debug("is_gradient_acc")
            continue

        # fwd_node should always exist, but handle non-existence just in case
        fwd_node = fwd_seq_nr_to_node.get(bw_node.meta["seq_nr"])
        if fwd_node is None:
            continue

        fwd_meta = fwd_node.meta
        bw_node.meta["fwd_nn_module_stack"] = fwd_meta.get("nn_module_stack")
        bw_node.meta["fwd_source_fn_stack"] = fwd_meta.get("source_fn_stack")
        # TODO: better to change to a specific field of custom?
        custom = fwd_meta.get("custom")
        if custom is not None:
            bw_node.meta["custom"] = copy.deepcopy(custom)


def copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None:
    """
    Input: `fx_g` which contains the joint fwd+bwd FX graph created by
    aot_autograd.

    Walk every GraphModule in ``fx_g`` (the root and any submodules, e.g. HOP
    subgraphs) and copy metadata from forward nodes onto backward nodes. The
    ``seq_nr`` field is used as a one-to-many mapping from forward node to
    backward nodes. The mapping is global, so a backward node in one submodule
    can match a forward node in another. This metadata is useful for
    performance profiling and debugging.
    """
    graph_modules = [
        submod for submod in fx_g.modules() if isinstance(submod, torch.fx.GraphModule)
    ]

    # seq_nr -> forward node, shared across all (sub)graphs.
    fwd_seq_nr_to_node: dict[str, torch.fx.Node] = {}
    for gm in graph_modules:
        _collect_fwd_nodes_from_subgraph(gm, fwd_seq_nr_to_node)

    if annotation_log.isEnabledFor(logging.DEBUG):
        for seq_nr, fwd_node in fwd_seq_nr_to_node.items():
            annotation_log.debug("forward:: key: %s, value: %s", seq_nr, fwd_node)

    # Second pass: annotate backward nodes using the global forward mapping.
    for gm in graph_modules:
        _copy_metadata_to_bw_nodes_in_subgraph(gm, fwd_seq_nr_to_node)


def register_buffer_assignment_hook(mod, assigned_buffers):
    """
    Register a hook that intercepts buffer assignments.
    This is used to detect when a buffer is assigned to, and then we can
    map that buffer to the corresponding proxy node in the graph.
    """

    def _map_assigned_buffer_to_proxy(_mod, name, buffer):
        # We intercept buffer assignments on the root module through this hook.
        if _mod._buffers is mod._buffers:
            # either buffer is a functional tensor, which wraps a fake tensor
            if isinstance(buffer, FunctionalTensor):
                buffer = buffer.from_functional()
            # or buffer is a fake tensor
            assert isinstance(buffer, FakeTensor)
            # The fake tensor in turn is associated with a proxy node.
            proxy_mode = torch.fx.experimental.proxy_tensor.get_proxy_mode()
            assert proxy_mode is not None
            proxy = torch.fx.experimental.proxy_tensor.get_proxy_slot(
                buffer, proxy_mode.tracer
            ).proxy.node
            # We map the assigned buffer to this proxy node.
            assigned_buffers[name] = proxy.name
        return buffer

    return torch.nn.modules.module.register_module_buffer_registration_hook(
        _map_assigned_buffer_to_proxy
    )


def contain_metadata_mutation_ops(module: torch.fx.GraphModule) -> bool:
    """
    Return True if any ``call_function`` node in ``module`` targets an op
    tagged ``torch.Tag.inplace_view`` (i.e. a metadata-mutating op).
    """
    return any(
        node.op == "call_function"
        and torch.Tag.inplace_view in getattr(node.target, "tags", ())
        for node in module.graph.nodes
    )


def get_cuda_generator_meta_val(device_idx: int):
    """
    Return a cloned state of the default CUDA generator for ``device_idx``,
    for use as a meta value.

    A newly cloned generator does not contain tensors; only generators
    registered to a CUDAGraph do. Because it holds no Tensor, it is fine to
    store in ``meta``.
    """
    default_generator = torch.cuda.default_generators[device_idx]
    return default_generator.clone_state()


def top_saved_tensors_hooks():
    return torch._C._autograd._top_saved_tensors_default_hooks(True)


def saved_tensors_hooks_are_inlineable(hooks) -> bool:
    """
    True when ``hooks`` is a non-empty ``(pack, unpack)`` pair in which both
    hooks are ``torch.fx.GraphModule`` instances.
    """
    if not hooks:
        return False
    pack, unpack = hooks
    return all(isinstance(hook, torch.fx.GraphModule) for hook in (pack, unpack))


# Type variables for `without_output_descs` (and `_P` also for `simple_wraps`):
# `_P` is the wrapped function's parameter list, `_T` the first element of the
# pair it returns (kept), `_S` the second element (dropped).
_P = ParamSpec("_P")
_T = TypeVar("_T")
_S = TypeVar("_S")


def without_output_descs(f: Callable[_P, tuple[_T, _S]]) -> Callable[_P, _T]:
    """
    Adapt ``f``, which returns an ``(outputs, output_descs)`` pair, into a
    function with the same parameters that returns only ``outputs``.

    The adapter is decorated with ``functools.wraps(f)``. It therefore carries
    ``f``'s ``__module__``/``__name__``/``__qualname__``/``__doc__`` and
    exposes ``f`` as ``__wrapped__``.
    """

    # Only ``wraps`` is applied. Stacking ``simple_wraps(f)`` underneath, as
    # before, had no effect: ``wraps`` re-copies every attribute it sets.
    @wraps(f)
    def inner(*args, **kwargs):
        # pyrefly: ignore [invalid-param-spec]
        return f(*args, **kwargs)[0]

    # pyrefly: ignore [bad-return]
    return inner


# Type variables for `simple_wraps`: `_R` is the return type of the function
# being wrapped; `_P2`/`_R2` are the parameters/return type of the decorated
# wrapper, which passes through unchanged.
_P2 = ParamSpec("_P2")
_R = TypeVar("_R")
_R2 = TypeVar("_R2")


def simple_wraps(
    f: Callable[_P, _R],
) -> Callable[[Callable[_P2, _R2]], Callable[_P2, _R2]]:
    """
    Like ``functools.wraps(f)``, but only copies ``__doc__``,
    ``__annotations__`` and ``__type_params__``. ``wraps`` still merges
    ``__dict__`` and sets ``__wrapped__``.
    """
    # NB: '__module__', '__name__' and '__qualname__' are intentionally not
    # copied, so the wrapper keeps its own identity; this eases debugging.
    copied_attrs = ("__doc__", "__annotations__", "__type_params__")
    return wraps(f, assigned=copied_attrs)


def call_and_expect_output_descs(fn, args):
    """
    Call ``fn(*args)`` and check that it returned a consistent
    ``(outs, outs_descs)`` pair.

    Checks, all via ``assert``:
    * the result is a 2-tuple;
    * ``outs`` and ``outs_descs`` flatten to the same pytree structure;
    * no ``AOTOutput`` leaks into ``outs``;
    * every Tensor, SymInt or plain ``int`` leaf of ``outs`` has an
      ``AOTOutput`` descriptor at the same position.

    Returns the pair exactly as ``fn`` returned it.
    """
    result = fn(*args)
    assert isinstance(result, tuple) and len(result) == 2, (fn, result)
    outs, outs_descs = result

    flat_outs, outs_spec = pytree.tree_flatten(outs)
    flat_descs, descs_spec = pytree.tree_flatten(outs_descs)
    assert outs_spec == descs_spec, (
        fn_wrappers(fn),
        outs,
        outs_descs,
        outs_spec,
        descs_spec,
    )
    assert all(not isinstance(value, AOTOutput) for value in flat_outs), (
        fn_wrappers(fn),
        outs,
        outs_descs,
        flat_outs,
    )

    def _requires_desc(value):
        # `type(value) is int` rejects int subclasses such as bool.
        return isinstance(value, (torch.Tensor, torch.SymInt)) or type(value) is int

    assert all(
        isinstance(desc, AOTOutput)
        for value, desc in zip(flat_outs, flat_descs)
        if _requires_desc(value)
    ), (fn_wrappers(fn), outs, outs_descs, flat_outs, flat_descs)
    return result


def fn_wrappers(fn):
    """
    Return ``[fn, fn.__wrapped__, fn.__wrapped__.__wrapped__, ...]``, following
    the ``__wrapped__`` chain until an object without that attribute is reached.
    """
    chain = [fn]
    while hasattr(chain[-1], "__wrapped__"):
        chain.append(chain[-1].__wrapped__)
    return chain
