Utility functions

Assorted utility functions for sparse and irregularly-structured tensors.


Vector lexicographical sorting

lexsort_nd(tensor, vector_dim, sort_dim, descending=False, stable=False, force_robust=False)

Sorts a tensor of vectors in lexicographic order.

Given a tensor of vectors, performs a sort that orders the vectors in lexicographic order. The vectors are defined along the vector_dim dimension, and sorted along the sort_dim dimension. If force_robust is False, then a fast lexicographic sort based on projecting the vectors to an order-preserving 1D basis is used if possible, falling back to a "robust" (true) multi-pass lexicographic sort if the input vectors cannot be losslessly compressed to 1D. If force_robust is True, the robust sort is always used. Both integer and floating-point tensors are supported.
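Example (a minimal usage sketch; the import path pytorch_sparse_utils.utils is assumed from the source location shown below and may differ in your install):

    import torch
    from pytorch_sparse_utils.utils import lexsort_nd  # assumed import path

    # Three 2D coordinate vectors; lexicographic order compares the first
    # component and breaks ties with the second.
    coords = torch.tensor([[2, 0],
                           [1, 5],
                           [1, 3]])  # shape (3, 2)

    # Vectors live along dim 1 and are reordered along dim 0.
    sorted_coords, sort_indices = lexsort_nd(coords, vector_dim=1, sort_dim=0)
    # Expected (first component as primary key):
    #   sorted_coords = [[1, 3], [1, 5], [2, 0]], sort_indices = [2, 1, 0]

    # force_robust=True always runs the multi-pass sort (one sort per vector
    # component) instead of the 1D projection, trading speed for reproducibility.
    sorted_robust, _ = lexsort_nd(coords, vector_dim=1, sort_dim=0, force_robust=True)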

Parameters:
  • tensor (Tensor) –

    Tensor to be sorted.

  • vector_dim (int) –

    Index along which vectors are defined.

  • sort_dim (int) –

    Index along which to sort.

  • descending (bool, default: False ) –

    If True, vectors are sorted in descending order. Default: False.

  • stable (bool, default: False ) –

    If True, stable sort is always used (order of equivalent values is kept). If False, unstable sorts are used when possible.

  • force_robust (bool, default: False ) –

    If True, always use the "true" iterative lexsort. This requires tensor.shape[vector_dim] sorts instead of 1 sort, but is more reproducible.

Returns:
  • tuple[Tensor, Tensor]

    - Tensor: Sorted tensor.
    - Tensor: Sort indices.

Notes
  • The relationship between the sorted tensor and the sort indices is (see the sketch below):

        sort_indices_exp = sort_indices.unsqueeze(vector_dim).expand_as(tensor)
        sorted_tensor = tensor.gather(sort_dim, sort_indices_exp)
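A sketch verifying this relationship with integer data (import path assumed as in the example above):

    import torch
    from pytorch_sparse_utils.utils import lexsort_nd  # assumed import path

    tensor = torch.randint(0, 10, (4, 8, 3))  # 8 vectors of length 3 per batch element
    vector_dim, sort_dim = 2, 1

    sorted_tensor, sort_indices = lexsort_nd(tensor, vector_dim, sort_dim)

    # Reconstruct the sorted tensor from the original tensor and the indices,
    # mirroring the relationship stated in the note above.
    sort_indices_exp = sort_indices.unsqueeze(vector_dim).expand_as(tensor)
    assert torch.equal(sorted_tensor, tensor.gather(sort_dim, sort_indices_exp))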
Source code in pytorch_sparse_utils/utils/lexsort_nd.py
@torch.jit.script
def lexsort_nd(
    tensor: Tensor,
    vector_dim: int,
    sort_dim: int,
    descending: bool = False,
    stable: bool = False,
    force_robust: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sorts a tensor of vectors in lexicographic order.

    Given a tensor of vectors, performs a sort that orders the vectors
    in lexicographic order. The vectors are defined along the `vector_dim` dimension,
    and sorted along the `sort_dim` dimension.
    If `force_robust` is False, then a fast lexicographic sort based on projecting the
    vectors to an order-preserving 1D basis is used if possible, falling back to a
    "robust" (true) multi-pass lexicographic sort if the input vectors cannot be
    losslessly compressed to 1D. If `force_robust` is True, the robust sort is always
    used.
    Both integer and floating-point tensors are supported.

    Args:
        tensor (Tensor): Tensor to be sorted.
        vector_dim (int): Index along which vectors are defined.
        sort_dim (int): Index along which to sort.
        descending (bool): If True, vectors are sorted in descending order. Default: False.
        stable (bool): If True, stable sort is always used (order of equivalent values is kept).
            If False, unstable sorts are used when possible.
        force_robust (bool): If True, always use the "true" iterative lexsort. This requires
            tensor.shape[vector_dim] sorts instead of 1 sort, but is more reproducible.

    Returns:
        tuple[Tensor, Tensor]:
            - Tensor: Sorted tensor.
            - Tensor: Sort indices.

    Notes:
        - The relationship between the sorted tensor and the sort indices is:
            sort_indices_exp = sort_indices.unsqueeze(vector_dim).expand_as(tensor)
            sorted_tensor = tensor.gather(sort_dim, sort_indices_exp).
    """
    # Normalize dims
    ndim = tensor.ndim
    vector_dim = vector_dim if vector_dim >= 0 else vector_dim + ndim
    sort_dim = sort_dim if sort_dim >= 0 else sort_dim + ndim

    # Input checks
    if vector_dim < 0 or vector_dim >= ndim:
        raise ValueError(
            f"Normalized vector_dim {vector_dim} is out of bounds for tensor with {ndim} "
            "dimensions."
        )
    if sort_dim < 0 or sort_dim >= ndim:
        raise ValueError(
            f"Normalized sort_dim {sort_dim} is out of bounds for tensor with {ndim} "
            "dimensions."
        )
    if sort_dim == vector_dim:
        raise ValueError(
            f"Expected vector_dim and sort_dim to be different, but got both "
            f"= {sort_dim}"
        )
    if tensor.isnan().any():
        raise ValueError("Tensor has nan values.")
    if tensor.isinf().any():
        raise ValueError("Tensor has infinite values.")

    # Get vector length
    vector_len = tensor.shape[vector_dim]

    # Handle edge cases
    if tensor.numel() == 0:
        indices_shape = list(tensor.shape)
        indices_shape.pop(vector_dim)
        return tensor, torch.zeros(
            indices_shape, device=tensor.device, dtype=torch.long
        )
    if tensor.size(sort_dim) == 1:
        indices_shape = list(tensor.shape)
        indices_shape.pop(vector_dim)
        return tensor, torch.zeros(
            indices_shape, device=tensor.device, dtype=torch.long
        )
    if vector_len == 1:  # Just do regular sort
        tensor, sort_indices = torch.sort(
            tensor,
            dim=sort_dim,
            descending=descending,
            stable=stable,
        )
        sort_indices = sort_indices.squeeze(vector_dim)
        return tensor, sort_indices

    # Move vector_dim to last position for projection reduction
    # and sort_dim to first position for faster sorting
    tensor_permuted, perm = _permute_dims(tensor, vector_dim, sort_dim)
    tensor_permuted = tensor_permuted.contiguous()

    # List of integer types
    _INT_TYPES = (
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.uint8,
    )

    # Pick appropriate sorting subroutine
    if force_robust:
        sorted_tensor_permuted, indices = _lexsort_nd_robust(
            tensor_permuted, descending=descending
        )
    elif torch.is_floating_point(tensor_permuted):
        sorted_tensor_permuted, indices = _lexsort_nd_float(
            tensor_permuted, descending, stable
        )
    elif tensor_permuted.dtype in _INT_TYPES:
        indices = _lexsort_nd_int(tensor_permuted, descending, stable).sort_indices
        sorted_tensor_permuted = None
    else:
        raise ValueError(f"Unsupported tensor dtype {tensor.dtype}")

    # Gather from the original tensor using the sort indices
    indices_unsq = indices.unsqueeze(-1)  # add singleton dim at permuted vector dim
    if sorted_tensor_permuted is None:  # get sorted tensor if not returned already
        sorted_tensor_permuted = torch.gather(
            tensor_permuted, dim=0, index=indices_unsq.expand_as(tensor_permuted)
        )

    # Permute tensor and indices back to original dimension order
    inverse_perm = [0] * tensor.ndim
    for i, p in enumerate(perm):
        inverse_perm[p] = i
    sorted_tensor = sorted_tensor_permuted.permute(inverse_perm)

    sort_indices = indices_unsq.permute(inverse_perm).squeeze(vector_dim)

    return sorted_tensor, sort_indices

Concatenated-batch topk

BatchTopK

Bases: NamedTuple

Source code in pytorch_sparse_utils/utils/batch_topk.py
class BatchTopK(NamedTuple):
    indices: Tensor
    offsets: Tensor
    values: Optional[Tensor] = None

batch_topk(tensor, batch_offsets, k, dim=0, largest=True, sorted=True, return_values=False)

Performs top-k operation on a batch-concatenated tensor with variable sequence lengths.

This function handles both uniform-length sequences (where a more efficient batch operation is used) and variable-length sequences (where per-batch processing occurs). The function returns indices adjusted to the original tensor's indexing space along with offsets to identify each batch's results.
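Example (a minimal sketch with two variable-length sequences; the import path pytorch_sparse_utils.utils is assumed):

    import torch
    from pytorch_sparse_utils.utils import batch_topk  # assumed import path

    # Two concatenated sequences of lengths 5 and 3, each with a feature dim of 4.
    tensor = torch.cat([torch.randn(5, 4), torch.randn(3, 4)], dim=0)  # (8, 4)
    batch_offsets = torch.tensor([0, 5, 8])  # (batch_size + 1,)

    out = batch_topk(tensor, batch_offsets, k=2, largest=True, return_values=True)
    # out.indices: flat top-k indices into the concatenated tensor (dim=0 default)
    # out.offsets: [0, 8, 16] -- each sequence contributes k * 4 = 8 entries
    # out.values:  matching top-k values, present because return_values=True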

Parameters:
  • tensor (Tensor) –

    A batch-concatenated tensor of shape (total_length, d1, d2, ...) where total_length is the sum of all sequence lengths.

  • batch_offsets (Tensor) –

    A 1D tensor of indices indicating where each sequence begins in the batch-concatenated tensor. Should be of shape (batch_size + 1,) with the last element being the total length.

  • k (Union[Tensor, int]) –

    Number of top elements to select. Can be an integer for the same k across all batches, or a tensor or list for different k per batch. Will be clamped to each sequence's length if k > sequence_length.

  • dim (int, default: 0 ) –

    Dimension along which to perform the top-k operation for each concatenated subsequence. Default: 0 (sequence dimension).

  • largest (bool, default: True ) –

    If True, returns the indices of the largest elements. If False, returns those of the smallest elements. Default: True.

  • sorted (bool, default: True ) –

    If True, always returns the elements in sorted order. For technical reasons, the returned elements may be sorted in some cases even when False. Default: True.

  • return_values (bool, default: False ) –

    If True, the output namedtuple will include the topk values in addition to the indices and offsets. Default: False.

Returns:
  • BatchTopK( BatchTopK ) –

    A namedtuple containing:

    - indices (Tensor): 1-D long tensor of the top-k indices in the original concatenated tensor space.
    - offsets (Tensor): 1-D long tensor of offsets with shape (batch_size + 1,) indicating where each batch's results begin in the indices tensor. Like the input batch_offsets, BatchTopK.offsets[-1] == len(BatchTopK.indices).
    - values (Optional[Tensor]): 1-D tensor of the same length as indices containing the actual top-k values. Returned only if return_values is True; otherwise None.

Usage Notes
  1. Slice per sequence
    The results for the b-th input sequence are obtained with

    idx_b   = out.indices[out.offsets[b] : out.offsets[b + 1]]
    if out.values is not None:
        val_b = out.values[out.offsets[b] : out.offsets[b + 1]]
    

    where out = batch_topk(...).

  2. Expected slice length
    Let
        seq_len_b = batch_offsets[b + 1] - batch_offsets[b]
        k_b       = min(k, seq_len_b)   # or k[b] if k is per-sample

    length(idx_b) is

    k_b * prod(tensor.size(i) for i not in {0, dim})      if dim == 0
    k_b * seq_len_b * prod(tensor.size(i) for i not in {0, dim})   otherwise
    

    This is exactly the number of elements returned by a regular torch.topk call on that subsequence.

  3. Reshaping to the usual top-k shape
    After selecting subseq_b = tensor[batch_offsets[b] : batch_offsets[b + 1]]
    you can turn the flat index slice back into the layout produced by torch.topk:

    out_shape = list(subseq_b.shape)
    out_shape[dim] = k_b
    idx_b = idx_b.view(*out_shape)
    

    The same out_shape works for val_b when it is present.

  4. Recovering the values when they were not returned
    If return_values=False, the values can still be gathered later.

    • When dim != 0 the indices are local to every subsequence, therefore you can gather directly:

      val_b = torch.take_along_dim(subseq_b, idx_b, dim)

    • When dim == 0 the indices are expressed in the global (concatenated-tensor) space. Either use them on the original concatenated tensor,

      val_b = torch.take_along_dim(tensor, idx_b, dim)

    or convert them back to local coordinates first:

       idx_b_local = idx_b - batch_offsets[b]
       val_b       = torch.take_along_dim(subseq_b, idx_b_local, dim)
    
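Putting Notes 1-4 together for the default dim=0, a runnable sketch (import path assumed as above):

    import torch
    from pytorch_sparse_utils.utils import batch_topk  # assumed import path

    tensor = torch.randn(7, 3)               # two sequences of lengths 4 and 3
    batch_offsets = torch.tensor([0, 4, 7])
    k = 2
    out = batch_topk(tensor, batch_offsets, k)   # dim=0, so indices are global

    b = 0
    # Note 1: slice this sequence's results out of the flat outputs
    idx_b = out.indices[out.offsets[b] : out.offsets[b + 1]]

    # Notes 2-3: reshape the flat slice back to the usual torch.topk layout
    subseq_b = tensor[batch_offsets[b] : batch_offsets[b + 1]]
    k_b = min(k, subseq_b.size(0))
    out_shape = list(subseq_b.shape)
    out_shape[0] = k_b                       # dim == 0 here
    idx_b = idx_b.view(out_shape)

    # Note 4: for dim == 0 the indices are global, so localize before gathering
    val_b = torch.take_along_dim(subseq_b, idx_b - batch_offsets[b], dim=0)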
Source code in pytorch_sparse_utils/utils/batch_topk.py
@torch.jit.script
def batch_topk(
    tensor: Tensor,
    batch_offsets: Tensor,
    k: Union[Tensor, list[int], int],
    dim: int = 0,
    largest: bool = True,
    sorted: bool = True,
    return_values: bool = False,
) -> BatchTopK:
    """
    Performs top-k operation on a batch-concatenated tensor with variable sequence lengths.

    This function handles both uniform-length sequences (where a more efficient batch
    operation is used) and variable-length sequences (where per-batch processing occurs).
    The function returns indices adjusted to the original tensor's indexing space
    along with offsets to identify each batch's results.

    Args:
        tensor (Tensor): A batch-concatenated tensor of shape (total_length, d1, d2, ...)
            where total_length is the sum of all sequence lengths.
        batch_offsets (Tensor): A 1D tensor of indices indicating where each sequence
            begins in the batch-concatenated tensor. Should be of shape (batch_size + 1,)
            with the last element being the total length.
        k (Union[Tensor, int]): Number of top elements to select. Can be an integer
            for the same k across all batches, or a tensor or list for different k per
            batch. Will be clamped to each sequence's length if k > sequence_length.
        dim (int, optional): Dimension along which to perform the top-k operation for
            each concatenated subsequence. Default: 0 (sequence dimension).
        largest (bool, optional): If True, returns the indices of the largest elements.
            If False, returns those of the smallest elements. Default: True.
        sorted (bool, optional): If True, always returns the elements in sorted order.
            For technical reasons, the returned elements may be sorted in some cases
            even when False. Default: True.
        return_values (bool, optional): If True, the output namedtuple will include the
            topk values in addition to the indices and offsets. Default: False.

    Returns:
        BatchTopK: A namedtuple containing:
            - indices (Tensor): 1-D long tensor of the top-k indices in the original
                concatenated tensor space.
            - offsets (Tensor): 1-D long tensor of offsets with shape (batch_size + 1)
                indicating where each batch's results begin in the topk_indices tensor.
                Like the input batch_offsets,
                BatchTopK.offsets[-1] == len(BatchTopK.indices).
            - values (Optional[Tensor]): 1-D tensor of the same length as `indices` with
                the actual topk values. Returned only if return_values is True, otherwise
                is None.

    Usage Notes:
        1.  Slice per sequence
            The results for the b-th input sequence are obtained with

                idx_b   = out.indices[out.offsets[b] : out.offsets[b + 1]]
                if out.values is not None:
                    val_b = out.values[out.offsets[b] : out.offsets[b + 1]]

            where `out = batch_topk(...)`.

        2.  Expected slice length
            Let
                seq_len_b = batch_offsets[b + 1] - batch_offsets[b]
                k_b       = min(k, seq_len_b)       # or k[b] if k is per-sample

            length(idx_b) is

                k_b * prod(tensor.size(i) for i not in {0, dim})      if dim == 0
                k_b * seq_len_b * prod(tensor.size(i) for i not in {0, dim})   otherwise

            This is exactly the number of elements returned by a
            regular torch.topk call on that subsequence.

        3.  Reshaping to the usual top-k shape
            After selecting `subseq_b = tensor[batch_offsets[b] : batch_offsets[b + 1]]`
            you can turn the flat index slice back into the layout produced by
            torch.topk:

                out_shape = list(subseq_b.shape)
                out_shape[dim] = k_b
                idx_b = idx_b.view(*out_shape)

            The same `out_shape` works for `val_b` when it is present.

        4.  Recovering the values when they were not returned
            If `return_values=False`, the values can still be gathered later.

            -  When ``dim != 0`` the indices are **local** to every subsequence,
               therefore you can gather directly:

                   val_b = torch.take_along_dim(subseq_b, idx_b, dim)

            -  When ``dim == 0`` the indices are expressed in the **global**
               (concatenated-tensor) space.  Either use them on the original
               concatenated tensor,

                   val_b = torch.take_along_dim(tensor, idx_b, dim)

               or convert them back to local coordinates first:

                   idx_b_local = idx_b - batch_offsets[b]
                   val_b       = torch.take_along_dim(subseq_b, idx_b_local, dim)
    """
    seq_lens = batch_offsets_to_seq_lengths(batch_offsets)
    assert isinstance(seq_lens, Tensor)
    batch_size = seq_lens.numel()

    # Normalize k
    k = _normalize_k(k, batch_size, tensor.device)
    assert isinstance(k, Tensor)
    if torch.any(k < 0):
        raise ValueError(f"Expected nonnegative value for `k`, got {k}")

    # Normalize dim if negative
    dim = dim if dim >= 0 else dim + tensor.ndim
    if dim < 0 or dim >= tensor.dim():
        raise ValueError(
            "Normalized dimension "
            f"{dim} is out of bounds for tensor with {tensor.ndim} dimensions"
        )

    # Early return for empty tensor
    if batch_size == 0 or tensor.numel() == 0:
        topk_indices = torch.empty(0, device=tensor.device, dtype=torch.long)
        topk_offsets = torch.zeros(1, device=tensor.device, dtype=torch.long)
        topk_values = tensor[0:0].flatten() if return_values else None  # keep gradients
        return BatchTopK(topk_indices, topk_offsets, topk_values)

    # Find product of dims besides seq and topk dims
    prod_extra_dims = 1
    for i, s in enumerate(tensor.shape):
        if i not in (0, dim):
            prod_extra_dims *= s

    same_len = torch.all(seq_lens == seq_lens[0])
    same_k = torch.all(k == k[0])

    if same_len:  # sequences same length, batch for efficiency
        # Clamp k to seq length
        seq_len_int = int(seq_lens[0].item())
        if dim == 0:
            k_max = torch.min(k.amax(), seq_lens[0])
        else:
            k_max = k.amax().clamp_max(tensor.shape[dim])
        k = k.clamp_max(k_max)
        k_max_int = int(k_max.item())

        # Compute per-batch result size
        if dim == 0:
            out_sizes = k * prod_extra_dims
        else:
            out_sizes = k * seq_len_int * prod_extra_dims
        topk_offsets = seq_lengths_to_batch_offsets(out_sizes)
        assert isinstance(topk_offsets, Tensor)

        if k_max_int == 0:
            topk_indices = tensor.new_empty(0, dtype=torch.long)
            topk_values = tensor[0:0].flatten() if return_values else None
            return BatchTopK(topk_indices, topk_offsets, topk_values)

        # reshape to [bsz, seq_len, ...]
        batch_shape = (batch_size, seq_len_int) + tensor.shape[1:]
        topk_dim = dim + 1  # account for new leading batch dim

        values_all, indices_all = tensor.reshape(batch_shape).topk(
            k_max_int, topk_dim, largest=largest, sorted=True
        )  # Need to be sorted to be able to select first k for each subseq

        # If topk is along sequence length, need to add offsets to indices
        # to globalize them
        if dim == 0:
            unsqueeze_shape = (batch_size,) + (1,) * (indices_all.ndim - 1)
            indices_all = indices_all + batch_offsets[:-1].view(unsqueeze_shape)

        if same_k:
            topk_indices = indices_all.flatten()
            topk_values = values_all.flatten() if return_values else None
            return BatchTopK(topk_indices, topk_offsets, topk_values)

        # not all same k: slice into the topk output for each batch
        total_len = int(topk_offsets[-1])
        topk_indices = tensor.new_empty(total_len, dtype=torch.long)
        topk_values = tensor.new_empty(total_len) if return_values else None

        for b in range(batch_size):
            k_b = int(k[b])
            if k_b == 0:
                continue

            # slice into the already-computed result
            start, end = int(topk_offsets[b]), int(topk_offsets[b + 1])
            topk_indices[start:end] = indices_all[b].narrow(dim, 0, k_b).flatten()
            if return_values:
                assert topk_values is not None
                topk_values[start:end] = values_all[b].narrow(dim, 0, k_b).flatten()

        return BatchTopK(topk_indices, topk_offsets, topk_values)

    ##########
    # Slow path
    # -- Sequences different length, run topk for each --

    if dim == 0:
        batch_seq_ks = seq_lens.clamp_max(k)
        batch_out_sizes = batch_seq_ks * prod_extra_dims
    else:
        dim_size = tensor.size(dim)
        batch_seq_ks = k.clamp_max(dim_size)
        batch_out_sizes = batch_seq_ks * seq_lens * prod_extra_dims
    topk_offsets = seq_lengths_to_batch_offsets(batch_out_sizes)
    assert isinstance(topk_offsets, Tensor)

    # Allocate result tensor(s)
    topk_indices = torch.empty(
        int(topk_offsets[-1].item()), dtype=torch.long, device=tensor.device
    )
    topk_values: Optional[Tensor] = None
    if return_values:
        topk_values = tensor.new_empty(int(topk_offsets[-1]))

    # Allocate a "scratch" buffer to hold the topk values outputs
    max_seq_len = int(seq_lens.max().item())
    max_k = int(batch_seq_ks.max().item()) if dim == 0 else int(k.max().item())

    scratch_shape = list(tensor.shape)
    scratch_shape[dim] = max_k
    if dim != 0:
        scratch_shape[0] = max_seq_len
    scratch_values = tensor.new_empty(scratch_shape)

    # per-batch topk
    for b, k_b in enumerate(batch_seq_ks):
        k_b = int(k_b)
        batch_start, batch_end = int(batch_offsets[b]), int(batch_offsets[b + 1])
        slice_start, slice_end = int(topk_offsets[b]), int(topk_offsets[b + 1])

        # slice of the big concatted sequence tensor that represents this subsequence
        subseq_b = tensor[batch_start:batch_end]

        # Set up the view into the output topk tensor
        subseq_shape = list(subseq_b.shape)
        subseq_shape[dim] = k_b
        topk_inds_subseq_view = topk_indices[slice_start:slice_end].view(subseq_shape)

        # Set up view into reusable scratch holder for topk values output
        if dim == 0:
            val_buffer = scratch_values[:k_b]
        else:
            val_buffer = scratch_values.narrow(0, 0, subseq_b.size(0))
            val_buffer = val_buffer.narrow(dim, 0, k_b)

        _topk_out(
            subseq_b.detach(),
            k_b,
            dim=dim,
            largest=largest,
            sorted=sorted,
            out_values=val_buffer,
            out_indices=topk_inds_subseq_view,
        )
        if return_values:
            # clone topk inds to save unmodified tensor for take_along_dim backward
            values = torch.take_along_dim(subseq_b, topk_inds_subseq_view.clone(), dim)
            assert topk_values is not None
            topk_values[slice_start:slice_end] = values.reshape(-1)
        if dim == 0:
            topk_inds_subseq_view.add_(batch_start)

    return BatchTopK(topk_indices, topk_offsets, topk_values)

unpack_batch_topk(result, batch_offsets, original_shape, dim=0)

Re-shape and localize the flattened indices/values inside a BatchTopK object.

Parameters:
  • result (BatchTopK) –

    The object returned by batch_topk.

  • batch_offsets (Tensor) –

    Same offsets that were passed to batch_topk.

  • original_shape (list[int]) –

    tensor.shape of the concatenated tensor that was given to batch_topk.

  • dim (int, default: 0 ) –

    Dimension along which top-k was computed (same value that was given to batch_topk).

Returns:
  • indices_per_batch( list[Tensor] ) –

    List containing one tensor per input sequence. Each tensor is the same shape as returned by a call to torch.topk(subseq, ...) for that subsequence.

  • values_per_batch( Optional[list[Tensor]] ) –

    List containing one tensor per input sequence. Like indices_per_batch, each tensor will be the same shape as returned by a standalone call to torch.topk(subseq, ...). If batch_topk was originally called with return_values=False, then values_per_batch will be None.
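Example (a minimal sketch showing how unpack_batch_topk pairs with batch_topk; the import path pytorch_sparse_utils.utils is assumed):

    import torch
    from pytorch_sparse_utils.utils import batch_topk, unpack_batch_topk  # assumed import path

    tensor = torch.randn(9, 2)               # two sequences of lengths 5 and 4
    batch_offsets = torch.tensor([0, 5, 9])

    result = batch_topk(tensor, batch_offsets, k=3, return_values=True)
    indices_per_batch, values_per_batch = unpack_batch_topk(
        result, batch_offsets, original_shape=list(tensor.shape), dim=0
    )
    # indices_per_batch[b] and values_per_batch[b] match the shapes produced by
    # torch.topk(tensor[batch_offsets[b]:batch_offsets[b + 1]], 3, dim=0),
    # with indices localized to each subsequence.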

Source code in pytorch_sparse_utils/utils/batch_topk.py
@torch.jit.script
def unpack_batch_topk(
    result: BatchTopK,
    batch_offsets: Tensor,
    original_shape: list[int],
    dim: int = 0,
) -> tuple[list[Tensor], Optional[list[Tensor]]]:
    """
    Re-shape and localize the flattened `indices`/`values` inside a `BatchTopK` object.

    Args:
        result (BatchTopK): The object returned by :pyfunc:`batch_topk`.
        batch_offsets (Tensor): Same offsets that were passed to `batch_topk`.
        original_shape (list[int]): ``tensor.shape`` of the concatenated tensor that
            was given to `batch_topk`.
        dim (int): Dimension along which top-k was computed (same value that was given
            to `batch_topk`).

    Returns:
        indices_per_batch (list[Tensor]): List containing one tensor per input
            sequence. Each tensor is the same shape as returned by a call to
            torch.topk(subseq, ...) for that subsequence.
        values_per_batch (Optional[list[Tensor]]): List containing one tensor
            per input sequence. Like indices_per_batch, each tensor will be the
            same shape as returned by a standalone call to torch.topk(subseq, ...).
            If `batch_topk` was originally called with `return_values=False`,
            then `values_per_batch` will be None.
    """
    # Normalize possibly negative dim
    dim = dim if dim >= 0 else dim + len(original_shape)

    indices_per_batch: list[Tensor] = []
    values_per_batch: Optional[list[Tensor]] = [] if result.values is not None else None

    # Compute the product of the other dims to determine subsequence topk size
    prod_other_dims = 1
    for i, s in enumerate(original_shape):
        if i not in (0, dim):
            prod_other_dims *= s

    for b in range(batch_offsets.numel() - 1):
        # Sub-range of the concatenated tensor
        start, end = int(batch_offsets[b]), int(batch_offsets[b + 1])
        seq_len_b = end - start

        # Slice into the flattened top-k output
        slice_start, slice_end = int(result.offsets[b]), int(result.offsets[b + 1])
        idx_flat_global = result.indices[slice_start:slice_end]

        # Convert to local coordinates when top-k was along the sequence dim
        idx_flat_local = idx_flat_global - start if dim == 0 else idx_flat_global

        # Derive k_b from the number of elements
        if idx_flat_local.numel() == 0:
            k_b = 0
        elif dim == 0:
            k_b = idx_flat_local.numel() // prod_other_dims
        else:
            k_b = idx_flat_local.numel() // (seq_len_b * prod_other_dims)

        # Build the full output shape for this subsequence
        out_shape = list(original_shape)
        if dim == 0:
            out_shape[0] = k_b
        else:
            out_shape[0] = seq_len_b
            out_shape[dim] = k_b

        # Reshape and store
        indices_per_batch.append(idx_flat_local.view(out_shape))

        if result.values is not None:
            vals = result.values
            assert vals is not None
            vals_b = vals[slice_start:slice_end]
            assert values_per_batch is not None
            values_per_batch.append(vals_b.view(out_shape))  # type: ignore

    return indices_per_batch, values_per_batch