import collections
import logging
import operator
from collections import defaultdict
from collections.abc import Callable
from typing import Any, Literal, TypeAlias

import torch
import torch.distributed as dist
import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import detect_fake_mode
from torch._inductor.comm_analysis import (
    get_collective_type_from_kernel_name,
    NCCL_COLL,
)
from torch._inductor.runtime.runtime_utils import dynamo_timed
from torch._logging import trace_structured
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.traceback import NodeSource, NodeSourceAction
from torch.utils._ordered_set import OrderedSet


# Module logger. Note that its level is forced to INFO at import time.
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Artifact logger for the "overlap" artifact; it debug-logs which nodes were
# bucketed together and which nodes replaced them.
overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")

# Bucketing strategy selector. Modes containing "custom_ops" select merge
# functions built on the `bucketing::` custom ops (e.g. for reduce_scatter);
# modes containing "multidtype" let all_gathers of different dtypes share a
# bucket, because their group key then omits the dtype.
BucketMode: TypeAlias = Literal["default", "custom_ops", "custom_ops_multidtype"]


# Group-key helpers: collectives may only be bucketed together when their
# group keys are equal (see greedy_bucket_collective_by_mb).
def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]:  # type: ignore[name-defined]
    _, group_size, group_name = node.args
    dtype = node.meta["val"].dtype
    assert isinstance(group_name, str)
    return (group_name, dtype)


def _ag_group_key_multidtype(node: torch.fx.Node) -> tuple[str]:
    _, group_size, group_name = node.args
    assert isinstance(group_name, str)
    return (group_name,)


def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:  # type: ignore[name-defined]
    _, reduce_op, group_size, group_name = node.args
    dtype = node.meta["val"].dtype
    assert isinstance(group_name, str)
    assert isinstance(reduce_op, str)
    return (group_name, reduce_op, dtype)


def _ar_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
    _, reduce_op, group_name = node.args
    dtype = node.meta["val"].dtype
    assert isinstance(group_name, str)
    assert isinstance(reduce_op, str)
    return (group_name, reduce_op, dtype)


def _schedulable_wait_node(node: torch.fx.Node) -> bool:
    """
    Check whether a wait node may be scheduled.

    We should not schedule an fx node that is:
        1. a wait on a node that is not a ``call_function`` with a callable target,
        2. a wait on a node whose target has no ``name()`` method (e.g. a plain
           Python callable such as ``operator.getitem``), since no kernel name
           can be derived from it,
        3. a wait on a non-NCCL communication node.

    Returns:
        True only for a wait_tensor whose awaited node is a supported collective.
    """
    if not is_wait_tensor(node):
        return False
    waited = node.args[0]
    assert isinstance(waited, torch.fx.Node)
    target = waited.target
    if waited.op != "call_function" or not isinstance(target, Callable):
        return False
    # Plain Python callables (e.g. operator.getitem on a coalesced collective's
    # output) have no ``name()``; calling it unconditionally raised
    # AttributeError, so treat such waits as unschedulable instead.
    name_fn = getattr(target, "name", None)
    if not callable(name_fn):
        return False
    coll: NCCL_COLL = get_collective_type_from_kernel_name(name_fn())
    return coll != NCCL_COLL.UNSUPPORTED


def _populate_node_meta(
    bucket_nodes: list[torch.fx.Node], new_nodes: list[torch.fx.Node]
):
    if bucket_nodes:
        for n in new_nodes:
            # For the following keys, we only store the information of the first node so
            # gm.print_readable shows some information
            # Full information are stored in "bucketing_{key}_sources"
            for key, default in [
                ("nn_module_stack", ""),
                ("fwd_nn_module_stack", ""),
                ("stack_trace", ""),
                ("custom", {}),
            ]:
                n.meta[key] = bucket_nodes[0].meta.get(key, default)

                # Collect sources from all bucket nodes for this metadata key, for debugging purposes only
                bucketing_sources_key = f"bucketing_{key}_sources"
                # Use set to remove duplicates
                if key == "stack_trace":
                    sources = OrderedSet(
                        [
                            node.meta.get(key, default)
                            for node in bucket_nodes
                            if node.meta.get(key, default)
                        ]
                    )
                else:
                    # type might not be hashable
                    sources = [
                        node.meta.get(key, default)
                        for node in bucket_nodes
                        if node.meta.get(key, default)
                    ]
                n.meta[bucketing_sources_key] = sources

            # used by inductor provenance tracking
            n.meta["from_node"] = [
                NodeSource(
                    original_node,
                    "bucketing_pass",
                    [NodeSourceAction.CREATE, NodeSourceAction.REPLACE],
                )
                for original_node in bucket_nodes
            ]


def bucket_key(node: torch.fx.Node, mode: BucketMode | None = None) -> object | None:
    """
    Return the key that decides which collectives ``node`` may be bucketed
    with, or ``None`` if ``node`` is not a bucketable collective.

    all_gather keys omit the dtype when ``mode`` contains "multidtype".
    """
    if is_all_gather_into_tensor(node):
        ignore_dtype = mode is not None and "multidtype" in mode
        if ignore_dtype:
            return _ag_group_key_multidtype(node)
        return _ag_group_key(node)
    if is_reduce_scatter_tensor(node):
        return _rs_group_key(node)
    if is_all_reduce_tensor(node):
        return _ar_group_key(node)
    return None


def pick_bucket_dtype(dtypes: list[torch.dtype]) -> torch.dtype:  # type: ignore[name-defined]
    assert len(dtypes) > 0
    return min(dtypes, key=operator.attrgetter("itemsize"))


def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
    """
    Default bucket-size policy: every bucket gets the same cap.

    Args:
        bucket_id (int): Index of the bucket; ignored by this default policy.

    Returns:
        float: Cap of the bucket in megabytes (always 2000.0).
    """
    del bucket_id  # every bucket shares the same cap
    return float(2000)


def bucket_all_gather(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
    mode: BucketMode = "default",
) -> None:
    """
    Bucket the all_gather collectives of ``gm`` in place.

    Args:
        gm: GraphModule to rewrite.
        bucket_cap_mb_by_bucket_idx: Bucket cap in MB as a function of the
            bucket index; when None, ``bucket_cap_mb_by_bucket_idx_default``
            from ``torch._inductor.fx_passes.bucketing`` is used.
        mode: Bucketing mode forwarded to bucket selection and merging.
    """
    cap_fn = bucket_cap_mb_by_bucket_idx
    if cap_fn is None:
        from torch._inductor.fx_passes.bucketing import (  # pyrefly: ignore  # missing-module-attribute
            bucket_cap_mb_by_bucket_idx_default,
        )

        cap_fn = bucket_cap_mb_by_bucket_idx_default
    ag_buckets = bucket_all_gather_by_mb(gm, cap_fn, None, mode)
    if ag_buckets:
        merge_all_gather(gm, ag_buckets, mode)


def bucket_reduce_scatter(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
    mode: BucketMode = "default",
) -> None:
    """
    Bucket the reduce_scatter collectives of ``gm`` in place.

    Args:
        gm: GraphModule to rewrite.
        bucket_cap_mb_by_bucket_idx: Bucket cap in MB as a function of the
            bucket index; when None, ``bucket_cap_mb_by_bucket_idx_default``
            from ``torch._inductor.fx_passes.bucketing`` is used.
        mode: Bucketing mode forwarded to bucket selection and merging.
    """
    cap_fn = bucket_cap_mb_by_bucket_idx
    if cap_fn is None:
        from torch._inductor.fx_passes.bucketing import (  # pyrefly: ignore  # missing-module-attribute
            bucket_cap_mb_by_bucket_idx_default,
        )

        cap_fn = bucket_cap_mb_by_bucket_idx_default
    rs_buckets = bucket_reduce_scatter_by_mb(gm, cap_fn, None, mode)
    if rs_buckets:
        merge_reduce_scatter(gm, rs_buckets, mode)


def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:  # type: ignore[arg-type]
    return node.op == "call_function" and (
        node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
        or node.target == torch.ops._c10d_functional.all_gather_into_tensor_out.default
    )


def is_reduce_scatter_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target is torch.ops._c10d_functional.reduce_scatter_tensor.default
    )


def is_wait_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target is torch.ops._c10d_functional.wait_tensor.default
    )


def is_all_reduce_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target is torch.ops._c10d_functional.all_reduce.default
    )


def is_all_to_all_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target is torch.ops._c10d_functional.all_to_all_single.default
    )


def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
    """Return True if ``node`` is a wait_tensor whose first argument is an all_gather."""
    if not is_wait_tensor(node):
        return False
    return is_all_gather_into_tensor(node.args[0])  # type: ignore[arg-type]


def collect_node_descendants(
    graph: torch.fx.Graph,
) -> dict[torch.fx.Node, OrderedSet[torch.fx.Node]]:
    """
    Collects the descendants of each node in the graph.
    Args:
        graph (torch.fx.Graph): The graph to collect descendants from.
    Returns:
        dict[torch.fx.Node, OrderedSet[torch.fx.Node]]: A dictionary mapping each node to its descendants.
    """
    node_descendants: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = (
        collections.defaultdict(OrderedSet)
    )
    outdegree = collections.defaultdict(int)
    queue = []

    for node in graph.nodes:
        n_outdegree = len(node.users)
        if n_outdegree == 0:
            queue.append(node)
        else:
            outdegree[node] = len(node.users)

    while queue:
        node = queue.pop()
        for input_node in node.all_input_nodes:
            node_descendants[input_node] |= node_descendants[node]
            node_descendants[input_node].add(node)
            outdegree[input_node] -= 1

            if outdegree[input_node] == 0:
                queue.append(input_node)

    return node_descendants


def greedy_bucket_collective_by_mb(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
    filter_node: Callable[[torch.fx.Node], bool],
    node_group_key: Callable[[torch.fx.Node], Any],
    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
) -> list[list[torch.fx.Node]]:
    """
    Bucket adjacent collectives with equal ``node_group_key``.

    We can not bucket non-adjacent collectives, as this would effectively
    change the order of collectives, and reordering can lead to a different
    order on different ranks.

    Collectives are visited in the order of their wait nodes. Consecutive
    collectives that match ``filter_node`` (and whose wait matches
    ``filter_wait_node``, if given) and share a group key form a group. Each
    group is then split greedily into buckets. Bucket ``i`` of a group holds at
    most ``bucket_cap_mb_by_bucket_idx(i)`` MB; indices restart at 0 for every
    group. A collective is sized by the larger of its output and its first
    input. A collective that transitively depends on a node already in the
    current bucket is skipped and left unbucketed, because merging it would
    create a cycle.

    Args:
        gm: GraphModule to analyze (not modified).
        bucket_cap_mb_by_bucket_idx: Cap of bucket ``i``, in megabytes.
        filter_node: Selects the collective nodes to consider.
        node_group_key: Only collectives with equal keys may share a bucket.
        filter_wait_node: Optional extra filter on the collectives' wait nodes.

    Returns:
        Buckets of collective nodes; each bucket holds at least two nodes.
    """
    g = gm.graph
    if not any(filter_node(node) for node in g.nodes):
        return []

    # TODO: pearce kelly algorithm for detecting cycles
    node_descendents = collect_node_descendants(g)

    # Pass 1: walk wait nodes in graph order and split the collectives they
    # wait on into runs of consecutive nodes with an equal group key.
    nodes_groups: list[list[torch.fx.Node]] = []
    cur_group: list[torch.fx.Node] = []
    cur_group_key = None
    for node in g.nodes:
        if not (is_wait_tensor(node) and filter_node(node.args[0])):
            continue
        if filter_wait_node is not None and not filter_wait_node(node):
            continue
        coll_node = node.args[0]
        group_key = node_group_key(coll_node)
        if group_key == cur_group_key:
            cur_group.append(coll_node)
        else:
            if len(cur_group) > 1:
                nodes_groups.append(cur_group)
            cur_group = [coll_node]
            cur_group_key = group_key
    if len(cur_group) > 1:
        nodes_groups.append(cur_group)

    def cap_bytes(bucket_id: int) -> int:
        # Byte cap of bucket ``bucket_id``; the callable returns megabytes.
        return int(bucket_cap_mb_by_bucket_idx(bucket_id) * 1024 * 1024)

    # Pass 2: greedily pack each group into size-capped buckets.
    buckets: list[list[torch.fx.Node]] = []
    for nodes in nodes_groups:
        cur_bucket: list[torch.fx.Node] = []
        cur_bucket_descendents: OrderedSet[torch.fx.Node] = OrderedSet()
        cur_bucket_size_bytes: int = 0
        cur_bucket_id: int = 0
        bucket_size_bytes = cap_bytes(cur_bucket_id)
        for node in nodes:
            if node in cur_bucket_descendents:
                # node depends (transitively) on a node already in the current
                # bucket; fusing them horizontally would create a cycle.
                continue
            assert "val" in node.meta
            n_val = node.meta["val"]
            out_size_bytes = n_val.numel() * n_val.element_size()
            n_input_val = node.all_input_nodes[0].meta["val"]
            in_size_bytes = n_input_val.numel() * n_input_val.element_size()
            size_bytes = max(out_size_bytes, in_size_bytes)
            if cur_bucket_size_bytes + size_bytes > bucket_size_bytes and cur_bucket:
                # Current bucket is full, create new bucket
                if len(cur_bucket) > 1:
                    buckets.append(cur_bucket)
                cur_bucket = []
                cur_bucket_size_bytes = 0
                cur_bucket_id += 1
                # The cap depends on the bucket index, so re-query it for the
                # new bucket; otherwise every bucket would use bucket 0's cap.
                bucket_size_bytes = cap_bytes(cur_bucket_id)
                cur_bucket_descendents = OrderedSet()
            cur_bucket_size_bytes += size_bytes
            cur_bucket.append(node)
            cur_bucket_descendents |= node_descendents[node]
        if len(cur_bucket) > 1:
            buckets.append(cur_bucket)
    return buckets


def bucket_all_gather_by_mb(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
    mode: BucketMode = "default",
) -> list[list[torch.fx.Node]]:
    """
    Identify all all_gather nodes and group them into buckets, based on the
    size limit ``bucket_cap_mb_by_bucket_idx``.

    Args:
        gm (torch.fx.GraphModule): GraphModule where to bucket all_gathers.
        bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Cap of a bucket,
            in megabytes, as a function of the bucket index. This allows the
            first buckets to have different sizes, as the first all_gather is
            usually exposed. ``bucket_cap_mb_by_bucket_idx_default`` shows the
            expected interface.
        filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified,
            only all_gather nodes whose wait node satisfies it are bucketed.
        mode (BucketMode): If it contains "multidtype", all_gathers of
            different dtypes may share a bucket (the group key omits dtype).

    Returns:
        list[list[torch.fx.Node]]: Buckets, each a list of all_gather nodes.
    """
    allow_mixed_dtypes = bool(mode) and "multidtype" in mode
    key_fn = _ag_group_key_multidtype if allow_mixed_dtypes else _ag_group_key
    return greedy_bucket_collective_by_mb(
        gm,
        bucket_cap_mb_by_bucket_idx,
        filter_node=is_all_gather_into_tensor,
        node_group_key=key_fn,
        filter_wait_node=filter_wait_node,
    )


def bucket_reduce_scatter_by_mb(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
    mode: BucketMode = "default",
) -> list[list[torch.fx.Node]]:
    """
    Identify all reduce_scatter nodes and group them into buckets, based on
    the size limit ``bucket_cap_mb_by_bucket_idx``.

    Args:
        gm (torch.fx.GraphModule): GraphModule where to bucket reduce_scatters.
        bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Cap of a bucket,
            in megabytes, as a function of the bucket index, so buckets can
            have different sizes.
        filter_wait_node (Callable[[torch.fx.Node], bool] | None): If specified,
            only reduce_scatter nodes whose wait node satisfies it are bucketed.
        mode (BucketMode): Bucketing mode; "multidtype" modes are rejected
            with an AssertionError.

    Returns:
        list[list[torch.fx.Node]]: Buckets, each a list of reduce_scatter nodes.
    """
    assert "multidtype" not in mode, (
        "reduce scatter bucketing does not support multidtype"
    )

    return greedy_bucket_collective_by_mb(
        gm,
        bucket_cap_mb_by_bucket_idx,
        filter_node=is_reduce_scatter_tensor,
        node_group_key=_rs_group_key,
        filter_wait_node=filter_wait_node,
    )


def bucket_all_reduce_by_mb(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
) -> list[list[torch.fx.Node]]:
    """
    Identify all all_reduce nodes and group them into buckets, based on the
    size limit ``bucket_cap_mb_by_bucket_idx``.

    Args:
        gm: GraphModule where to bucket all_reduces.
        bucket_cap_mb_by_bucket_idx: Cap of a bucket, in megabytes, as a
            function of the bucket index.
        filter_wait_node: If specified, only all_reduce nodes whose wait node
            satisfies it are bucketed.

    Returns:
        Buckets, each a list of all_reduce nodes.
    """
    return greedy_bucket_collective_by_mb(
        gm,
        bucket_cap_mb_by_bucket_idx,
        filter_node=is_all_reduce_tensor,
        node_group_key=_ar_group_key,
        filter_wait_node=filter_wait_node,
    )


def bucket_all_reduce(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
    mode: str | None = None,
) -> None:
    """
    Bucket the all_reduce collectives of ``gm`` in place, one bucket at a time.

    Args:
        gm: GraphModule to rewrite.
        bucket_cap_mb_by_bucket_idx: Bucket cap in MB as a function of the
            bucket index; when None, ``bucket_cap_mb_by_bucket_idx_default``
            from ``torch._inductor.fx_passes.bucketing`` is used.
        mode: Forwarded to ``merge_all_reduce_bucket``.
    """
    cap_fn = bucket_cap_mb_by_bucket_idx
    if cap_fn is None:
        from torch._inductor.fx_passes.bucketing import (  # pyrefly: ignore  # missing-module-attribute
            bucket_cap_mb_by_bucket_idx_default,
        )

        cap_fn = bucket_cap_mb_by_bucket_idx_default
    for ar_bucket in bucket_all_reduce_by_mb(gm, cap_fn):
        merge_all_reduce_bucket(gm.graph, ar_bucket, mode)


@torch.library.custom_op("bucketing::_pre_bucket_reduce_scatter", mutates_args={})
def _pre_bucket_reduce_scatter(
    rs_ins: list[torch.Tensor],
    group_size: int,
) -> torch.Tensor:
    rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins]
    new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()
    return new_rs_in


def _pre_bucket_reduce_scatter_fake(
    rs_ins: list[torch.Tensor],
    group_size: int,
) -> torch.Tensor:
    out_numel = sum(rs_in.numel() for rs_in in rs_ins)
    return torch.empty((out_numel,), device=rs_ins[0].device, dtype=rs_ins[0].dtype)


# Give the custom op its fake (shape/dtype-only) implementation so it can be
# traced with fake tensors, as `_trace` does.
_pre_bucket_reduce_scatter.register_fake(_pre_bucket_reduce_scatter_fake)


def reduce_scatter_merge_fn_to_trace_custom_ops(
    rs_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    reduce_op: str,
    reduce_dtype: torch.dtype,  # type: ignore[name-defined]
    device: torch.device,  # type: ignore[name-defined]
) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
    """
    Bucketed reduce_scatter built on the ``bucketing::_pre_bucket_reduce_scatter``
    custom op; this function is traced (via ``_trace``) and inserted into the graph.

    The inputs are packed into one flat buffer, reduce-scattered once, and the
    result is split back into one output per input.

    Args:
        rs_ins: Inputs of the original reduce_scatter collectives.
        group_size: Process-group size; dim 0 of each output is
            ``input.shape[0] // group_size``.
        group_name: Process group passed to ``reduce_scatter_tensor``.
        reduce_op: Reduction op passed to ``reduce_scatter_tensor``.
        reduce_dtype: Not used in this function.
        device: Not used in this function.

    Returns:
        One tensor per entry of ``rs_ins``, in the same order.
    """
    # Shape/numel of each original reduce_scatter output: dim 0 shrinks by group_size.
    new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins]
    new_out_numels = [x.numel() // group_size for x in rs_ins]

    # Pack all inputs into a single flat buffer with the custom op.
    new_rs_in = torch.ops.bucketing._pre_bucket_reduce_scatter(rs_ins, group_size)

    # TODO - either use torch.cat or make sure inductor foreach codegen
    # fires more reliably
    new_rs_out = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.reduce_scatter_tensor.default(
            new_rs_in, reduce_op, group_size, group_name
        )
    )
    # Consecutive chunks of the flat result belong to consecutive inputs.
    new_out_flat = new_rs_out.split(new_out_numels, 0)
    new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)]
    return new_outs


def reduce_scatter_merge_fn_to_trace(
    rs_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    reduce_op: str,
    reduce_dtype: torch.dtype,  # type: ignore[name-defined]
    device: torch.device,  # type: ignore[name-defined]
) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
    """
    Bucketed reduce_scatter using plain aten ops; this function is traced
    (via ``_trace``) and inserted into the graph.

    Each input is viewed as ``(group_size, -1)``, the views are concatenated
    along dim 1 and flattened, one reduce_scatter runs on the result, and the
    output is split back into one tensor per input.

    Args:
        rs_ins: Inputs of the original reduce_scatter collectives.
        group_size: Process-group size; dim 0 of each output is
            ``input.shape[0] // group_size``.
        group_name: Process group passed to ``reduce_scatter_tensor``.
        reduce_op: Reduction op passed to ``reduce_scatter_tensor``.
        reduce_dtype: Not used in this function.
        device: Not used in this function.

    Returns:
        One tensor per entry of ``rs_ins``, in the same order.
    """
    rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins]

    # Shape/numel of each original reduce_scatter output: dim 0 shrinks by group_size.
    new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins]
    new_out_numels = [x.numel() // group_size for x in rs_ins]

    # Row r of the concatenation holds the r-th slice of every input.
    new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()

    new_rs_out = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.reduce_scatter_tensor.default(
            new_rs_in, reduce_op, group_size, group_name
        )
    )
    # Consecutive chunks of the flat result belong to consecutive inputs.
    new_out_flat = new_rs_out.split(new_out_numels, 0)
    new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)]
    return new_outs


def all_reduce_merge_fn_to_trace(
    ar_ins: list[torch.Tensor],
    group_name: str,
    reduce_op: str,
    reduce_dtype: torch.dtype,  # type: ignore[name-defined]
    device: torch.device,  # type: ignore[name-defined]
) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
    """
    Bucketed all_reduce; this function is traced (via ``_trace``) and inserted
    into the graph.

    All inputs are flattened and concatenated, one all_reduce runs on the
    result, and the output is split and viewed back to each input's shape.

    Args:
        ar_ins: Inputs of the original all_reduce collectives.
        group_name: Process group passed to ``all_reduce``.
        reduce_op: Reduction op passed to ``all_reduce``.
        reduce_dtype: Not used in this function.
        device: Not used in this function.

    Returns:
        One tensor per entry of ``ar_ins``, with the same shapes, in order.
    """
    ar_ins_flattened = [x.view(-1) for x in ar_ins]
    new_ar_in = torch.cat(ar_ins_flattened)
    new_ar_out = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.all_reduce.default(new_ar_in, reduce_op, group_name)
    )
    # Undo the concatenation: one chunk per input, in input order.
    split_sizes = [x.numel() for x in ar_ins]
    new_outs_flat = new_ar_out.split(split_sizes)
    new_outs = [x.view(ar_in.shape) for x, ar_in in zip(new_outs_flat, ar_ins)]
    return new_outs


# List of all torch dtypes for serialization through custom ops
# TODO: custom ops support list[dtype] input
_ALL_DTYPES = tuple(
    [
        getattr(torch, attr)
        for attr in dir(torch)
        if isinstance(getattr(torch, attr), torch.dtype)
    ]
)


@torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={})
def _pre_bucket_all_gather(
    ag_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    dtype: torch.dtype,  # type: ignore[name-defined]
    out_dtype_ints: list[int],
    rank: int,
) -> torch.Tensor:
    """
    Allocate the flat bucketed all_gather buffer and copy this rank's inputs into it.

    Args:
        ag_ins: Inputs to be gathered.
        group_size: Number of equal slots in the buffer.
        group_name: Not used in this function (part of the op schema).
        dtype: dtype of the flat buffer.
        out_dtype_ints: For each input, the index into ``_ALL_DTYPES`` of the
            dtype it is converted to before the all_gather.
        rank: Index of the slot that receives the packed inputs.

    Returns:
        A 1-D ``dtype`` tensor of ``group_size`` slots. Only slot ``rank`` is
        written; the other slots are left uninitialized.
    """
    # Convert int indices back to torch.dtype
    out_dtypes = [_ALL_DTYPES[idx] for idx in out_dtype_ints]
    bucket_itemsize = dtype.itemsize
    # Length of each input's region, counted in elements of the bucket dtype.
    region_sizes = [
        (ag_in.numel() * out_dtype.itemsize) // bucket_itemsize
        for ag_in, out_dtype in zip(ag_ins, out_dtypes, strict=True)
    ]
    slot_numel = sum(region_sizes)
    new_ag_out = torch.empty(
        slot_numel * group_size, dtype=dtype, device=ag_ins[0].device
    )
    own_slot = new_ag_out.narrow(0, slot_numel * rank, slot_numel)
    # Reinterpret each region in its output dtype; copying into it then
    # performs the input -> output dtype conversion.
    typed_regions = [
        region.view(out_dtype)
        for region, out_dtype in zip(
            torch.split(own_slot, region_sizes), out_dtypes, strict=True
        )
    ]
    flat_sources = [ag_in.reshape(-1) for ag_in in ag_ins]
    torch._foreach_copy_(typed_regions, flat_sources)
    return new_ag_out


def _pre_bucket_all_gather_fake(
    ag_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    dtype: torch.dtype,  # type: ignore[name-defined]
    out_dtype_ints: list[int],
    rank: int,
) -> torch.Tensor:
    """
    Fake implementation of ``bucketing::_pre_bucket_all_gather``.

    Only the output's size, dtype and device are computed; ``group_name`` and
    ``rank`` do not affect the result.
    """
    out_dtypes = [_ALL_DTYPES[idx] for idx in out_dtype_ints]
    bucket_itemsize = dtype.itemsize
    region_sizes = [
        (ag_in.numel() * out_dtype.itemsize) // bucket_itemsize
        for ag_in, out_dtype in zip(ag_ins, out_dtypes, strict=True)
    ]
    slot_numel = sum(region_sizes)
    return torch.empty(
        slot_numel * group_size, dtype=dtype, device=ag_ins[0].device
    )


# Give the custom op its fake (shape/dtype-only) implementation so it can be
# traced with fake tensors, as `_trace` does.
_pre_bucket_all_gather.register_fake(_pre_bucket_all_gather_fake)


def all_gather_merge_fn_to_trace_custom_ops(
    _ag_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    dtype: torch.dtype,  # type: ignore[name-defined]
    out_dtypes: list[torch.dtype],  # type: ignore[name-defined]
    rank: int,
) -> list[torch.Tensor]:
    """
    Bucketed all_gather built on the ``bucketing::_pre_bucket_all_gather`` custom
    op; this function is traced (via ``_trace``) and inserted into the graph.

    The custom op copies the inputs, converted to their ``out_dtypes``, into
    slot ``rank`` of one flat ``dtype`` buffer. One
    ``all_gather_into_tensor_out`` then fills the whole buffer, and the result
    is split back into one gathered tensor per input.

    Args:
        _ag_ins: Inputs to gather (possibly not yet converted to ``out_dtypes``).
        group_size: Number of slots in the gathered buffer.
        group_name: Process group passed to the all_gather.
        dtype: dtype of the flat bucket buffer.
        out_dtypes: dtype each input is gathered as; also each output's dtype.
        rank: Slot of the buffer that holds the packed inputs.

    Returns:
        For each input of shape ``(d0, *rest)``, a tensor of shape
        ``(d0 * group_size, *rest)`` whose dtype is the matching ``out_dtype``.
    """
    # Don't create convert_element_type ops - _pre_bucket_all_gather handles conversion
    # by viewing destination slices as output dtypes and letting copy do the conversion
    ag_ins = _ag_ins
    ins_sizes = [ag_in.shape for ag_in in ag_ins]
    # Bytes each input occupies once converted to its output dtype...
    # NOTE(review): unlike _pre_bucket_all_gather, this zip is not strict=True,
    # so a length mismatch is only caught inside the custom op.
    ins_split_sizes_bytes = [
        ag_in.numel() * out_dtype.itemsize
        for ag_in, out_dtype in zip(ag_ins, out_dtypes)
    ]
    bucket_dtype_size_bytes = dtype.itemsize
    # ...expressed as a number of elements of the bucket dtype.
    ins_split_sizes = [
        _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
    ]
    ag_input_numel = sum(ins_split_sizes)

    # Convert out_dtypes to indices for custom_op
    # TODO: custom ops support list[dtype] input
    out_dtype_ints = [_ALL_DTYPES.index(dt) for dt in out_dtypes]

    # Allocates the full gathered buffer with slot `rank` already filled.
    new_ag_out = torch.ops.bucketing._pre_bucket_all_gather(
        ag_ins, group_size, group_name, dtype, out_dtype_ints, rank
    )
    # Slot `rank` is the all_gather input; the gather writes into new_ag_out.
    new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
    wait_tensor = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.all_gather_into_tensor_out.default(
            new_ag_in, group_size, group_name, out=new_ag_out
        )
    )
    # Row r is slot r of the gathered buffer; split its columns per input.
    new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
    outs_bucket_dtype = torch.split_with_sizes(
        new_ag_out_reshaped,
        ins_split_sizes,
        dim=1,
    )
    # Reinterpret each piece in its output dtype, then fold the slot
    # dimension into dim 0.
    outs_reshaped = [
        o.view(out_dtype).reshape((shape[0] * group_size,) + shape[1:])
        for o, shape, out_dtype in zip(outs_bucket_dtype, ins_sizes, out_dtypes)
    ]
    return outs_reshaped


def all_gather_merge_fn_to_trace(
    ag_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    dtype: torch.dtype,  # type: ignore[name-defined]
    out_dtypes: list[torch.dtype],  # type: ignore[name-defined]
    rank: int,
) -> list[torch.Tensor]:
    """
    Bucketed all_gather using aten ops (``_foreach_copy_`` into a preallocated
    buffer); this function is traced (via ``_trace``) and inserted into the graph.

    The inputs are copied into slot ``rank`` of one flat ``dtype`` buffer, one
    ``all_gather_into_tensor_out`` fills the whole buffer, and the result is
    split back into one gathered tensor per input.

    Args:
        ag_ins: Inputs to gather.
        group_size: Number of slots in the gathered buffer.
        group_name: Process group passed to the all_gather.
        dtype: dtype of the flat bucket buffer and of every output.
        out_dtypes: Not used in this function. Inputs of another dtype are
            written into the ``dtype`` buffer by ``_foreach_copy_`` (presumably
            converting, as ``copy_`` does — confirm).
        rank: Slot of the buffer that holds this call's inputs.

    Returns:
        For each input of shape ``(d0, *rest)``, a ``dtype`` tensor of shape
        ``(d0 * group_size, *rest)``.
    """
    ins_sizes = [ag_in.shape for ag_in in ag_ins]
    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
    ag_input_numel = sum(ins_split_sizes)
    device = ag_ins[0].device
    new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
    # Slot `rank` of the output doubles as the all_gather input.
    new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
    foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes)
    ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
    torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened)
    wait_tensor = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.all_gather_into_tensor_out.default(
            new_ag_in, group_size, group_name, out=new_ag_out
        )
    )
    # Row r is slot r of the gathered buffer; split its columns per input.
    new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
    outs = torch.split_with_sizes(
        new_ag_out_reshaped,
        ins_split_sizes,
        dim=1,
    )
    # Fold the slot dimension into dim 0 of each original input shape.
    outs_reshaped = [
        o.reshape((shape[0] * group_size,) + shape[1:])
        for o, shape in zip(outs, ins_sizes)
    ]
    return outs_reshaped


def all_gather_merge_fn_to_trace_functional(
    ag_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    dtype: torch.dtype,  # type: ignore[name-defined]
    out_dtypes: list[torch.dtype],  # type: ignore[name-defined]
    rank: int,
    use_fsdp_ag_copy_in: bool = False,
) -> list[torch.Tensor]:
    """
    Bucketed all_gather without in-place copies in the traced graph; this
    function is traced (via ``_trace``) and inserted into the graph.

    The all_gather input is built either with ``torch.cat`` or, when
    ``use_fsdp_ag_copy_in`` is set, with the custom op
    ``torch.ops.fsdp.all_gather_copy_in``, which also returns the output buffer.

    Args:
        ag_ins: Inputs to gather.
        group_size: Number of slots in the gathered buffer.
        group_name: Process group passed to the all_gather.
        dtype: dtype of the gathered buffer and of every output.
        out_dtypes: Not used in this function.
        rank: Only used by the ``fsdp.all_gather_copy_in`` path.
        use_fsdp_ag_copy_in: Build the input/output buffers with
            ``torch.ops.fsdp.all_gather_copy_in`` instead of ``torch.cat``.

    Returns:
        For each input of shape ``(d0, *rest)``, a tensor of shape
        ``(d0 * group_size, *rest)``.
    """
    # Implementation that is functional in graph; optionally uses the custom
    # op torch.ops.fsdp.all_gather_copy_in.
    ins_sizes = [ag_in.shape for ag_in in ag_ins]
    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
    ag_input_numel = sum(ins_split_sizes)
    device = ag_ins[0].device
    new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
    ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
    if use_fsdp_ag_copy_in:
        new_ag_in, new_ag_out = torch.ops.fsdp.all_gather_copy_in(
            ag_ins_flattened, new_ag_out, ins_split_sizes, ag_input_numel, rank
        )
    else:
        new_ag_in = torch.cat(ag_ins_flattened, dim=0)
    wait_tensor = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.all_gather_into_tensor_out.default(
            new_ag_in, group_size, group_name, out=new_ag_out
        )
    )
    # Row r is slot r of the gathered buffer; split its columns per input.
    new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
    outs = torch.split_with_sizes(
        new_ag_out_reshaped,
        ins_split_sizes,
        dim=1,
    )
    # Fold the slot dimension into dim 0 of each original input shape.
    outs_reshaped = [
        o.reshape((shape[0] * group_size,) + shape[1:])
        for o, shape in zip(outs, ins_sizes)
    ]
    return outs_reshaped


def _trace(fn, inps) -> torch.fx.GraphModule:  # type: ignore[no-untyped-def]
    with dynamo_timed("fx.bucketing._trace", log_pt2_compile_event=True):
        fake_mode = detect_fake_mode(inps)
        assert fake_mode is not None
        with fake_mode, enable_python_dispatcher():
            out = make_fx(fn)(*inps)
            for node in out.graph.find_nodes(
                op="call_function", target=torch.ops.aten.detach.default
            ):
                node.replace_all_uses_with(node.args[0])
                out.graph.erase_node(node)
            return out


def _insert_fn_trace_before_node(  # type: ignore[no-untyped-def]
    g: torch.fx.Graph,
    fn_to_trace,
    inps,
    insert_before_node: torch.fx.Node,
    g_fn_inps: list[torch.fx.Node],
    g_fn_outs: list[torch.fx.Node],
) -> tuple[dict[torch.fx.Node, torch.fx.Node], list[torch.fx.Node]]:  # type: ignore[no-untyped-def]
    """
    Trace :attr:`fn_to_trace` on :attr:`inps` and splice the resulting graph into
    :attr:`g` before :attr:`insert_before_node`.

    The first ``len(g_fn_inps)`` placeholders of the traced graph are bound to
    the :attr:`g_fn_inps` nodes of ``g``. The traced function's outputs then
    take over every use of the corresponding node in :attr:`g_fn_outs`.

    Returns:
        (replacements, new_nodes): a mapping from each node of ``g_fn_outs`` to
        the node replacing it, and every inserted node in insertion order (the
        traced output node itself is not kept).
    """
    with dynamo_timed(
        "fx.bucketing._insert_fn_trace_before_node", log_pt2_compile_event=True
    ):
        traced_graph = _trace(fn_to_trace, inps).graph
        placeholders = traced_graph.find_nodes(op="placeholder")
        # Map traced placeholders onto the caller-provided nodes of ``g``.
        env = {placeholders[i]: g_inp for i, g_inp in enumerate(g_fn_inps)}
        traced_outs: list[torch.fx.Node] = []
        inserted: list[torch.fx.Node] = []

        with g.inserting_before(insert_before_node):
            for traced_node in traced_graph.nodes:
                if traced_node.op == "placeholder":
                    continue
                copied = g.node_copy(traced_node, lambda x: env[x])
                env[traced_node] = copied
                if traced_node.op != "output":
                    inserted.append(copied)
                    continue
                # Keep the output's arguments but drop the copied output node.
                traced_outs = copied.args[0]  # type: ignore[assignment]
                g.erase_node(copied)

        replacements = dict(zip(g_fn_outs, traced_outs))
        for orig_out, new_out in zip(g_fn_outs, traced_outs):
            orig_out.replace_all_uses_with(new_out)

        return replacements, inserted


def has_mergeable_all_gather_convert_dtype(n: torch.fx.Node) -> bool:
    """
    Return True if ``n`` is an all_gather whose input is a dtype conversion
    (``prims.convert_element_type`` or ``aten._to_copy``) used only by ``n``.

    Bucketing can then gather the conversion's input directly and erase the
    conversion node. The all_gather check runs before ``n.args`` is read, so
    any node (including argument-less ones such as placeholders) can be
    passed safely.
    """
    if not is_all_gather_into_tensor(n) or not n.args:
        return False
    node_in = n.args[0]
    return (
        isinstance(node_in, torch.fx.Node)
        and node_in.op == "call_function"
        and (
            node_in.target is torch.ops.prims.convert_element_type.default
            or node_in.target is torch.ops.aten._to_copy.default
        )
        and len(node_in.users) == 1
    )


def process_collective_bucket(
    g: torch.fx.Graph,
    bucket_nodes: list[torch.fx.Node],
    fn_to_trace: Callable[..., list[torch.Tensor]],
    trace_args_fn: Callable[[list[torch.fx.Node]], tuple[Any, ...]],
    insert_before: torch.fx.Node | None = None,
    wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
    """
    Process a single bucket of collective operation nodes with flexible insertion control.

    Each node in ``bucket_nodes`` must have exactly one user, its wait node.
    The collectives' inputs are collected. For an all_gather fed by a
    single-use dtype conversion, the conversion's input is used instead and
    the conversion is erased. ``fn_to_trace`` is traced and inserted into
    ``g``, its outputs replace the old wait nodes, and the old waits,
    collectives and folded conversions are erased.

    Args:
        g: The graph to modify
        bucket_nodes: Nodes in the current bucket to process
        fn_to_trace: Function to trace and insert
        trace_args_fn: Builds the trace arguments from the collected input nodes
        insert_before: Where to insert the traced function (default: right after the last bucket node)
        wait_insertion_point: If provided, move all nodes from the first inserted wait() onwards to before this node

    Returns:
        new_nodes: List of all newly inserted nodes
        replacements: Dictionary mapping old wait nodes to new output nodes
    """
    # Collect inputs and waits from current bucket
    bucket_ins: list[torch.fx.Node] = []
    bucket_waits: list[torch.fx.Node] = []
    # Conversion nodes folded into each all_gather; erased together with it.
    ag_node_to_pre_nodes: dict[torch.fx.Node, list[torch.fx.Node]] = defaultdict(list)

    for n in bucket_nodes:
        assert len(n.users) == 1, f"Expected single user for {n}, got {n.users}"
        wait_n = next(iter(n.users))

        # Handle convert_element_type operations (for all_gather)
        node_in = n.args[0]
        if has_mergeable_all_gather_convert_dtype(n):
            ag_node_to_pre_nodes[n].append(node_in)
            node_in = node_in.args[0]

        assert isinstance(node_in, torch.fx.Node)  # Ensure node_in is a Node
        bucket_ins.append(node_in)
        bucket_waits.append(wait_n)

    # Create trace arguments
    trace_args = trace_args_fn(bucket_ins)

    # Determine insertion point
    if insert_before is None:
        insert_before = bucket_nodes[-1].next

    # Insert traced function and get replacements + new nodes
    replacements, new_nodes = _insert_fn_trace_before_node(
        g,
        fn_to_trace,
        trace_args,
        insert_before,
        bucket_ins,
        bucket_waits,
    )

    # If requested, move wait nodes and everything after to specified location
    if wait_insertion_point is not None:
        # Find the first wait node in new_nodes
        wait_start_idx = None
        for i, node in enumerate(new_nodes):
            if is_wait_tensor(node):
                wait_start_idx = i
                break

        # Move all nodes from wait onwards (including the wait). Each prepend
        # places the node directly before wait_insertion_point, so the moved
        # nodes keep their relative order.
        if wait_start_idx is not None:
            nodes_to_move = new_nodes[wait_start_idx:]
            for node in nodes_to_move:
                wait_insertion_point.prepend(node)

    # Preserve metadata from original collective nodes to new bucketed nodes
    if bucket_nodes:
        overlap_log.debug(
            "Bucketing nodes: %s, New nodes: %s",
            ",".join([n.name for n in bucket_nodes]),
            ",".join([n.name for n in new_nodes]),
        )
    _populate_node_meta(bucket_nodes, new_nodes)

    # Erase old nodes, users first: wait, then collective, then any folded
    # conversion feeding the collective.
    for node, wait_n in zip(bucket_nodes, bucket_waits):
        g.erase_node(wait_n)
        g.erase_node(node)
        # Erase any convert_element_type nodes we tracked
        for pre_node in reversed(ag_node_to_pre_nodes[node]):
            g.erase_node(pre_node)

    return new_nodes, replacements


def merge_reduce_scatter_bucket(
    g: torch.fx.Graph,
    rs_nodes: list[torch.fx.Node],
    mode: BucketMode = "default",
    insert_before: torch.fx.Node | None = None,
    wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
    """
    Fuse a bucket of reduce_scatter collectives into one bucketed reduce_scatter.

    Every node in ``rs_nodes`` must share the first node's reduce op, group
    size, group name, device and dtype; this is checked with ``assert``.

    Args:
        g: Graph that owns ``rs_nodes``; it is rewritten in place.
        rs_nodes: Reduce-scatter nodes to merge. Must be non-empty, because
            the first node supplies the reference attributes.
        mode: Bucketing mode. Any truthy mode containing ``"custom_ops"``
            selects the custom-op merge function.
        insert_before: Node before which the merged collective is inserted;
            forwarded to ``process_collective_bucket``.
        wait_insertion_point: Optional node before which the new wait node
            (and everything traced after it) is moved.

    Returns:
        The ``(new_nodes, replacements)`` pair produced by
        ``process_collective_bucket``.
    """
    head_val = rs_nodes[0].meta["val"]
    _, reduce_op, group_size, group_name = rs_nodes[0].args
    reduce_dtype = head_val.dtype
    device = head_val.device

    # Every member must be interchangeable with the head node for one merged
    # collective to stand in for all of them.
    for candidate in rs_nodes:
        candidate_val = candidate.meta["val"]
        assert (
            candidate.args[1] == reduce_op
            and candidate.args[2] == group_size
            and candidate.args[3] == group_name
            and candidate_val.device == device
            and candidate_val.dtype == reduce_dtype
        )

    use_custom_ops = bool(mode) and "custom_ops" in mode
    merge_fn = (
        reduce_scatter_merge_fn_to_trace_custom_ops
        if use_custom_ops
        else reduce_scatter_merge_fn_to_trace
    )

    def _fake_value(node: torch.fx.Node) -> Any:
        return node.meta["val"]

    # Built lazily: process_collective_bucket resolves the actual bucket inputs
    # (possibly looking through dtype conversions) before calling this.
    def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
        fake_inputs = pytree.tree_map(_fake_value, bucket_ins)
        return (fake_inputs, group_size, group_name, reduce_op, reduce_dtype, device)

    return process_collective_bucket(
        g,
        rs_nodes,
        merge_fn,
        create_trace_args,
        insert_before=insert_before,
        wait_insertion_point=wait_insertion_point,
    )


def merge_all_reduce_bucket(
    g: torch.fx.Graph,
    ar_nodes: list[torch.fx.Node],
    mode: str | None = None,
    insert_before: torch.fx.Node | None = None,
    wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
    """
    Fuse a bucket of all_reduce collectives into one bucketed all_reduce.

    Every node in ``ar_nodes`` must share the first node's reduce op, group
    name, device and dtype; this is checked with ``assert``.

    Args:
        g: Graph that owns ``ar_nodes``; it is rewritten in place.
        ar_nodes: All-reduce nodes to merge. Must be non-empty, because the
            first node supplies the reference attributes.
        mode: Accepted for signature parity with the other ``merge_*_bucket``
            helpers; it is not read here.
        insert_before: Node before which the merged collective is inserted;
            forwarded to ``process_collective_bucket``.
        wait_insertion_point: Optional node before which the new wait node
            (and everything traced after it) is moved.

    Returns:
        The ``(new_nodes, replacements)`` pair produced by
        ``process_collective_bucket``.
    """
    head_val = ar_nodes[0].meta["val"]
    _, reduce_op, group_name = ar_nodes[0].args
    reduce_dtype = head_val.dtype
    device = head_val.device

    # All members must agree with the head node so a single merged
    # all_reduce can replace them.
    for candidate in ar_nodes:
        candidate_val = candidate.meta["val"]
        assert (
            candidate.args[1] == reduce_op
            and candidate.args[2] == group_name
            and candidate_val.device == device
            and candidate_val.dtype == reduce_dtype
        )

    def _fake_value(node: torch.fx.Node) -> Any:
        return node.meta["val"]

    # Built lazily from whatever inputs process_collective_bucket resolves.
    def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
        fake_inputs = pytree.tree_map(_fake_value, bucket_ins)
        return (fake_inputs, group_name, reduce_op, reduce_dtype, device)

    return process_collective_bucket(
        g,
        ar_nodes,
        all_reduce_merge_fn_to_trace,
        create_trace_args,
        insert_before=insert_before,
        wait_insertion_point=wait_insertion_point,
    )


def merge_all_gather_bucket(
    g: torch.fx.Graph,
    ag_nodes: list[torch.fx.Node],
    mode: BucketMode = "default",
    insert_before: torch.fx.Node | None = None,
    wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
    """
    Merge a bucket of all_gather collectives into a single bucketed all_gather.

    All nodes must share the first node's group size and group name (checked
    with ``assert``). Their dtypes may differ; a common bucket dtype is chosen
    with ``pick_bucket_dtype``, and the per-node dtypes are passed to the merge
    function.

    Args:
        g: Graph that owns ``ag_nodes``; it is rewritten in place.
        ag_nodes: All-gather nodes to merge. Must be non-empty, because the
            first node supplies the reference group attributes.
        mode: Bucketing mode. A mode containing ``"custom_ops"`` selects the
            custom-op merge function.
        insert_before: Node before which the merged collective is inserted.
            When ``None``, ``process_collective_bucket`` inserts it after the
            last node of the bucket.
        wait_insertion_point: Optional node before which the new wait node
            (and everything traced after it) is moved.

    Returns:
        The ``(new_nodes, replacements)`` pair produced by
        ``process_collective_bucket``.
    """
    from torch.distributed.distributed_c10d import _resolve_process_group

    ag0 = ag_nodes[0]
    _, group_size, group_name = ag0.args
    assert isinstance(group_name, str)
    _ag_dtypes: list[torch.dtype] = []  # type: ignore[name-defined]

    for n in ag_nodes:
        assert n.args[1] == group_size and n.args[2] == group_name
        _ag_dtypes.append(n.meta["val"].dtype)

    bucket_dtype = pick_bucket_dtype(_ag_dtypes)

    # Choose merge function based on mode
    ag_merge_fn = all_gather_merge_fn_to_trace
    if mode is not None and "custom_ops" in mode:
        ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops  # type: ignore[assignment]

    # This process's rank within the collective's group, passed to the merge function.
    rank: int = dist.get_rank(_resolve_process_group(group_name))

    # Process bucket with lazy input collection
    def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
        return (
            pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
            group_size,
            group_name,
            bucket_dtype,
            _ag_dtypes,
            rank,
        )

    # Forward insert_before like the reduce_scatter / all_reduce variants do;
    # previously it was dropped, so callers' requested insertion point was ignored.
    return process_collective_bucket(
        g,
        ag_nodes,
        ag_merge_fn,
        create_trace_args,
        insert_before=insert_before,
        wait_insertion_point=wait_insertion_point,
    )


def merge_reduce_scatter(
    gm: torch.fx.GraphModule,
    rs_buckets: list[list[torch.fx.Node]],
    mode: BucketMode = "default",
) -> None:
    """
    Merge each bucket of reduce_scatter nodes into one joint reduce_scatter.

    The bucket layout is emitted as a structured-trace artifact before any
    rewriting. The whole pass is timed under
    ``fx.bucketing.merge_reduce_scatter``.

    Args:
        gm: Graph module whose graph is rewritten in place.
        rs_buckets: Groups of reduce_scatter nodes; each inner list becomes
            one merged collective.
        mode: Bucketing mode forwarded to ``merge_reduce_scatter_bucket``.
    """

    def _artifact_metadata() -> dict[str, str]:
        return {
            "name": "fx_bucketing_passes_reduce_scatter_buckets",
            "encoding": "string",
        }

    def _artifact_payload() -> str:
        return str(rs_buckets)

    with dynamo_timed("fx.bucketing.merge_reduce_scatter", log_pt2_compile_event=True):
        trace_structured(
            "artifact",
            metadata_fn=_artifact_metadata,
            payload_fn=_artifact_payload,
        )

        graph = gm.graph
        for bucket in rs_buckets:
            merge_reduce_scatter_bucket(graph, bucket, mode)


def merge_all_gather(
    gm: torch.fx.GraphModule,
    ag_buckets: list[list[torch.fx.Node]],
    mode: BucketMode = "default",
) -> None:
    """
    Merge each bucket of all_gather nodes into one joint all_gather.

    The bucket layout is emitted as a structured-trace artifact before any
    rewriting. The whole pass is timed under ``fx.bucketing.merge_all_gather``.

    Args:
        gm: Graph module whose graph is rewritten in place.
        ag_buckets: Groups of all_gather nodes; each inner list becomes one
            merged collective.
        mode: Bucketing mode forwarded to ``merge_all_gather_bucket``.
    """

    def _artifact_metadata() -> dict[str, str]:
        return {
            "name": "fx_bucketing_passes_all_gather_buckets",
            "encoding": "string",
        }

    def _artifact_payload() -> str:
        return str(ag_buckets)

    with dynamo_timed("fx.bucketing.merge_all_gather", log_pt2_compile_event=True):
        trace_structured(
            "artifact",
            metadata_fn=_artifact_metadata,
            payload_fn=_artifact_payload,
        )

        graph = gm.graph
        for bucket in ag_buckets:
            merge_all_gather_bucket(graph, bucket, mode)
