Indexing utilities

The indexing module provides operations for selecting, slicing, gathering from, and scattering into sparse tensors.
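
The examples in this reference call the functions by their bare names. A typical setup might look like the following (a minimal sketch; the import path is assumed from the submodule locations shown under each function):

>>> import torch
>>> from pytorch_sparse_utils.indexing import (
...     sparse_select, sparse_index_select, batch_sparse_index,
...     scatter_to_sparse_tensor, unique_rows, union_sparse_indices,
... )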

Basic Operations

sparse_select(tensor, axis, index)

Selects a single subtensor from a sparse tensor along the specified axis.

This function extracts a slice of a sparse tensor by selecting elements whose coordinate along the specified axis (sparse or dense) matches the given index value. Unlike PyTorch's built-in indexing, this implementation properly handles gradient flow for sparse tensors.

Parameters:
  • tensor (Tensor) –

    The input sparse tensor.

  • axis (int) –

    The dimension along which to select the subtensor. Negative axes are supported.

  • index (int) –

    The index to select along the specified axis. Negative indices are supported.

Returns:
  • Tensor( Tensor ) –

    A new sparse tensor with one fewer dimension than the input tensor. The shape will be tensor.shape[:axis] + tensor.shape[axis+1:].

Raises:
  • ValueError

    If the input tensor is not sparse, or if the axis or index are out of bounds.

Examples:

Create a sparse tensor and select slices:

>>> # Create a sparse tensor with dimensions (3, 4, 2), where the last
>>> # dimension is a dense dimension
>>> i = torch.tensor([[0, 1], [1, 3], [2, 2]]).T
>>> v = torch.tensor([[1.0, 1.5], [2.0, 2.5], [3.0, 3.5]])
>>> x = torch.sparse_coo_tensor(i, v, (3, 4, 2))
>>>
>>> # Select values along sparse dimension 0
>>> slice_0 = sparse_select(x, 0, 1)  # Get elements where dim 0 == 1
>>> # slice_0 has shape (4, 2)
>>>
>>> # Select values along dense dimension 2
>>> slice_2 = sparse_select(x, 2, 0)  # Get elements where dim 2 == 0
>>> # slice_2 has shape (3, 4)
Source code in pytorch_sparse_utils/indexing/basics.py
@torch.jit.script
def sparse_select(tensor: Tensor, axis: int, index: int) -> Tensor:
    """Selects a single subtensor from a sparse tensor along the specified axis.

    This function extracts a slice of a sparse tensor by selecting elements
    whose coordinate along the specified axis (sparse or dense) matches the given
    index value. Unlike PyTorch's built-in indexing, this implementation properly
    handles gradient flow for sparse tensors.

    Args:
        tensor (Tensor): The input sparse tensor.
        axis (int): The dimension along which to select the subtensor. Negative axes
            are supported.
        index (int): The index to select along the specified axis. Negative indices
            are supported.

    Returns:
        Tensor: A new sparse tensor with one fewer dimension than the input tensor.
            The shape will be tensor.shape[:axis] + tensor.shape[axis+1:].

    Raises:
        ValueError: If the input tensor is not sparse, or if the axis or index are
            out of bounds.

    Examples:
        Create a sparse tensor and select slices:

        >>> # Create a sparse tensor with dimensions (3, 4, 2), where the last
        >>> # dimension is a dense dimension
        >>> i = torch.tensor([[0, 1], [1, 3], [2, 2]]).T
        >>> v = torch.tensor([[1.0, 1.5], [2.0, 2.5], [3.0, 3.5]])
        >>> x = torch.sparse_coo_tensor(i, v, (3, 4, 2))
        >>>
        >>> # Select values along sparse dimension 0
        >>> slice_0 = sparse_select(x, 0, 1)  # Get elements where dim 0 == 1
        >>> # slice_0 has shape (4, 2)
        >>>
        >>> # Select values along dense dimension 2
        >>> slice_2 = sparse_select(x, 2, 0)  # Get elements where dim 2 == 0
        >>> # slice_2 has shape (3, 4)
    """
    if not tensor.is_sparse:
        raise ValueError("Input tensor is not sparse.")

    if not isinstance(index, int):
        raise ValueError(f"Expected integer index, got type {type(index)}")

    # Normalize negative axis
    orig_axis = axis
    if axis < 0:
        axis = tensor.ndim + axis

    # Validate axis
    if axis < 0 or axis >= tensor.ndim:
        raise ValueError(
            f"Axis {orig_axis} out of bounds for tensor with {tensor.ndim} dimensions"
        )

    # Normalize negative index
    orig_index = index
    if index < 0:
        index = tensor.size(axis) + index

    # check if index is in bounds
    if index < 0 or index >= tensor.size(axis):
        # weird string construction to work around torchscript not liking concatting
        # multiple f strings
        error_str = "Index " + str(orig_index) + " is out of bounds on axis "
        error_str += (
            str(orig_axis) + " for tensor with shape " + str(tensor.shape) + "."
        )
        raise ValueError(error_str)

    tensor = tensor.coalesce()

    sparse_dims = tensor.sparse_dim()
    if axis < sparse_dims:
        # Selection along a sparse dimension
        index_mask = tensor.indices()[axis] == index
        values = tensor.values()[index_mask]
        indices = torch.cat(
            [
                tensor.indices()[:axis, index_mask],
                tensor.indices()[axis + 1 :, index_mask],
            ]
        )
    else:
        # Selecting along a dense dimension
        # This means we just index the values tensor along the appropriate dim
        # and the sparse indices stay the same
        indices = tensor.indices()
        dense_axis = axis - sparse_dims + 1  # +1 because first dim of values is nnz
        values = tensor.values().select(dense_axis, index)

    return torch.sparse_coo_tensor(
        indices, values, tensor.shape[:axis] + tensor.shape[axis + 1 :]
    ).coalesce()

sparse_index_select(tensor, axis, index, check_bounds=True, disable_builtin_fallback=False)

Selects values from a sparse tensor along a specified dimension.

This function is equivalent to Tensor.index_select(axis, index) but offers full support for the backward pass for sparse tensors. It returns a new sparse tensor containing only the values at the specified indices along the given axis.

This function falls back to the built-in Tensor.index_select(axis, index) when gradients are not required. Benchmarking seems to indicate the built-in version is generally faster and more memory efficient except for some specialized situations on CUDA. You can always use the custom implementation by setting this function's input argument disable_builtin_fallback to True.

Note that the built-in Tensor.index_select will trigger Device-Side Assert errors if it is given indices outside the bounds of a sparse tensor. Unlike the built-in Tensor.index_select, this function validates that indices are within bounds (when check_bounds=True), making it a safer alternative even when gradient support isn't needed.
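
For example, the two flags can be combined to force the custom implementation and skip the bounds check (an illustrative sketch with an arbitrary tensor; skipping the check is only safe when the indices are known to be in range):

>>> x = torch.sparse_coo_tensor(
...     torch.tensor([[0, 1], [0, 1]]), torch.tensor([1.0, 2.0]), (2, 2)
... )
>>> out = sparse_index_select(
...     x, axis=0, index=torch.tensor([1, 0]),
...     check_bounds=False, disable_builtin_fallback=True,
... )
>>> out.to_dense()
tensor([[0., 2.],
        [1., 0.]])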

Parameters:
  • tensor (Tensor) –

    The input sparse tensor from which to select values.

  • axis (int) –

    The dimension along which to select values. Can be negative to index from the end.

  • index (Tensor) –

    The indices of the values to select along the specified dimension. Must be a 1D tensor or scalar of integer dtype.

  • check_bounds (bool, default: True ) –

    Whether to check if indices are within bounds. Set to False if indices are guaranteed to be in-bounds to avoid a CPU sync on CUDA tensors. Benchmarking shows the bounds check leads to an overhead of about 5% on CPU and 10% on CUDA. Defaults to True.

  • disable_builtin_fallback (bool, default: False ) –

    Whether to always use the custom gradient-tracking version of index_select, even when gradients are not needed. Does nothing if the input sparse tensor does require gradients. Defaults to False.

Returns:
  • Tensor( Tensor ) –

    A new sparse tensor containing the selected values.

Raises:
  • ValueError
    • If the input tensor is not sparse.
    • If the index tensor has invalid shape or is not an integer tensor.
    • If the axis is out of bounds for tensor dimensions.
    • If check_bounds is True and the index tensor contains out-of-bounds indices.

Examples:

>>> # Create a sparse tensor with shape (4, 5)
>>> indices = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
>>> values = torch.tensor([10.0, 20.0, 30.0, 40.0])
>>> sparse_tensor = torch.sparse_coo_tensor(indices, values, (4, 5))
>>> # Select rows 1 and 3
>>> selected = sparse_index_select(sparse_tensor, axis=0, index=torch.tensor([1, 3]))
>>> selected.to_dense()
tensor([[ 0.,  0., 20.,  0.,  0.],
        [ 0.,  0.,  0.,  0., 40.]])
>>> # Select columns 2, 3, and 4
>>> selected = sparse_index_select(sparse_tensor, axis=1, index=torch.tensor([2, 3, 4]))
>>> selected.to_dense()
tensor([[ 0.,  0.,  0.],
        [20.,  0.,  0.],
        [ 0., 30.,  0.],
        [ 0.,  0., 40.]])
>>> # Works with negative axis indexing
>>> selected = sparse_index_select(sparse_tensor, axis=-1, index=torch.tensor([4, 2]))
>>> selected.to_dense()
tensor([[ 0.,  0.],
        [ 0., 20.],
        [ 0.,  0.],
        [40.,  0.]])
>>> # Gradient support example
>>> values = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 0]])
>>> x = torch.sparse_coo_tensor(indices, values, (3, 2))
>>> selected = sparse_index_select(x, axis=0, index=torch.tensor([0, 2]))
>>> loss = selected.sum()
>>> loss.backward()
>>> values.grad  # Gradient flows back to original values
tensor([1., 0., 1.])
Source code in pytorch_sparse_utils/indexing/basics.py
@torch.jit.script
def sparse_index_select(
    tensor: Tensor,
    axis: int,
    index: Tensor,
    check_bounds: bool = True,
    disable_builtin_fallback: bool = False,
) -> Tensor:
    """Selects values from a sparse tensor along a specified dimension.

    This function is equivalent to Tensor.index_select(axis, index) but offers full
    support for the backward pass for sparse tensors. It returns a new sparse
    tensor containing only the values at the specified indices along the given axis.

    This function falls back to the built-in Tensor.index_select(axis, index)
    when gradients are not required. Benchmarking seems to indicate the built-in
    version is generally faster and more memory efficient except for some specialized
    situations on CUDA. You can always use the custom implementation by setting this
    function's input argument disable_builtin_fallback to True.

    Note that the built-in Tensor.index_select will trigger Device-Side Assert errors
    if it is given indices outside the bounds of a sparse tensor.
    Unlike the built-in Tensor.index_select, this function validates that indices
    are within bounds (when check_bounds=True), making it a safer alternative even
    when gradient support isn't needed.

    Args:
        tensor (Tensor): The input sparse tensor from which to select values.
        axis (int): The dimension along which to select values. Can be negative
            to index from the end.
        index (Tensor): The indices of the values to select along the specified
            dimension. Must be a 1D tensor or scalar of integer dtype.
        check_bounds (bool, optional): Whether to check if indices are within bounds.
            Set to False if indices are guaranteed to be in-bounds to avoid a CPU sync
            on CUDA tensors. Benchmarking shows the bounds check leads to an overhead
            of about 5% on CPU and 10% on CUDA. Defaults to True.
        disable_builtin_fallback (bool, optional): Whether to always use the custom
            gradient-tracking version of index_select, even when gradients are not
            needed. Does nothing if the input sparse tensor does require gradients.
            Defaults to False.

    Returns:
        Tensor: A new sparse tensor containing the selected values.

    Raises:
        ValueError:
            - If the input tensor is not sparse.
            - If the index tensor has invalid shape or is not an integer tensor.
            - If the axis is out of bounds for tensor dimensions.
            - If check_bounds is True and the index tensor contains out-of-bounds
              indices.

    Examples:
        >>> # Create a sparse tensor with shape (4, 5)
        >>> indices = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
        >>> values = torch.tensor([10.0, 20.0, 30.0, 40.0])
        >>> sparse_tensor = torch.sparse_coo_tensor(indices, values, (4, 5))

        >>> # Select rows 1 and 3
        >>> selected = sparse_index_select(sparse_tensor, axis=0, index=torch.tensor([1, 3]))
        >>> selected.to_dense()
        tensor([[ 0.,  0., 20.,  0.,  0.],
                [ 0.,  0.,  0.,  0., 40.]])

        >>> # Select columns 2, 3, and 4
        >>> selected = sparse_index_select(sparse_tensor, axis=1, index=torch.tensor([2, 3, 4]))
        >>> selected.to_dense()
        tensor([[ 0.,  0.,  0.],
                [20.,  0.,  0.],
                [ 0., 30.,  0.],
                [ 0.,  0., 40.]])

        >>> # Works with negative axis indexing
        >>> selected = sparse_index_select(sparse_tensor, axis=-1, index=torch.tensor([4, 2]))
        >>> selected.to_dense()
        tensor([[ 0.,  0.],
                [ 0., 20.],
                [ 0.,  0.],
                [40.,  0.]])

        >>> # Gradient support example
        >>> values = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
        >>> indices = torch.tensor([[0, 1, 2], [0, 1, 0]])
        >>> x = torch.sparse_coo_tensor(indices, values, (3, 2))
        >>> selected = sparse_index_select(x, axis=0, index=torch.tensor([0, 2]))
        >>> loss = selected.sum()
        >>> loss.backward()
        >>> values.grad  # Gradient flows back to original values
        tensor([1., 0., 1.])
    """
    if not tensor.is_sparse:
        raise ValueError("Input tensor must be sparse")

    # Validate index tensor
    if torch.is_floating_point(index):
        raise ValueError(f"Received index tensor of non-integer dtype: {index.dtype}")
    if index.ndim > 1:
        raise ValueError(f"Index tensor must be 0D or 1D, got {index.ndim}D")
    elif index.ndim == 0:
        index = index.unsqueeze(0)

    # Normalize negative axis
    orig_axis = axis
    if axis < 0:
        axis = tensor.ndim + axis

    # Validate axis
    if axis < 0 or axis >= tensor.ndim:
        raise ValueError(
            f"Axis {orig_axis} out of bounds for tensor with {tensor.ndim} dimensions"
        )

    # Validate index bounds (optional)
    if check_bounds and index.numel() > 0:
        out_of_bounds = ((index < 0) | (index >= tensor.shape[axis])).any()
        if out_of_bounds:  # cpu sync happens here
            raise ValueError(
                f"Index tensor has entries out of bounds for axis {orig_axis} with size {tensor.shape[axis]}"
            )

    if not tensor.requires_grad and not disable_builtin_fallback:
        # Fall back to built-in implementation
        return tensor.index_select(axis, index.long()).coalesce()

    tensor = tensor.coalesce()

    tensor_indices = tensor.indices()
    tensor_values = tensor.values()

    new_indices, new_values = _sparse_index_select_inner(
        tensor_indices, tensor_values, axis, index
    )

    new_shape = list(tensor.shape)
    new_shape[axis] = len(index)

    return torch.sparse_coo_tensor(new_indices.long(), new_values, new_shape).coalesce()

Bulk Indexing

batch_sparse_index(sparse_tensor, index_tensor, check_all_specified=False)

Batch selection of elements from a torch sparse tensor. The index tensor may have arbitrary batch dimensions.

If dense_tensor = sparse_tensor.to_dense(), the equivalent dense indexing operation would be dense_tensor[index_tensor.unbind(-1)].
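
This equivalence can be checked on a small example (an illustrative sketch with an arbitrary tensor):

>>> dense_tensor = torch.tensor([[0.0, 1.0], [2.0, 0.0]])
>>> sparse_tensor = dense_tensor.to_sparse()
>>> index_tensor = torch.tensor([[0, 1], [1, 0]])
>>> selected, mask = batch_sparse_index(sparse_tensor, index_tensor)
>>> torch.equal(selected, dense_tensor[index_tensor.unbind(-1)])
True
>>> mask  # both queried indices are specified (nonzero) in sparse_tensor
tensor([True, True])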

Parameters:
  • sparse_tensor (Tensor) –

    Sparse tensor of dimension [S0, S1, ..., Sn-1, D0, D1, ..., Dm-1], with n sparse dimensions and m dense dimensions.

  • index_tensor (Tensor) –

    Long tensor of dimension [B0, B1, ..., Bp-1, n] with optional p leading batch dimensions and final dimension corresponding to the sparse dimensions of the sparse tensor. Negative indices are not supported and will be considered unspecified.

  • check_all_specified (bool, default: False ) –

    If True, this function will raise a ValueError if any of the indices in index_tensor are not specified in sparse_tensor. If False, selections at unspecified indices will be returned with padding values of 0. Defaults to False.

Returns:
  • Tensor( Tensor ) –

    Tensor of dimension [B0, B1, ..., Bp-1, D0, D1, ..., Dm-1].

  • Tensor( Tensor ) –

    Boolean tensor of dimension [B0, B1, ..., Bp-1], where each element is True if the corresponding index is a specified (nonzero) element of the sparse tensor and False if not.

Raises:
  • ValueError

    If check_all_specified is set to True and not all indices in index_tensor had associated values specified in sparse_tensor, or if index_tensor is a nested tensor (feature planned but not implemented yet)

Examples:

>>> # Create a 3D sparse tensor with 2 sparse dims and 1 dense dim
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 2]])
>>> values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3, 2))
>>> # Single index lookup
>>> index = torch.tensor([[1, 1]])
>>> result, mask = batch_sparse_index(sparse, index)
>>> result
tensor([[3., 4.]])
>>> mask
tensor([True])
>>> # Batch lookup with some unspecified indices
>>> batch_indices = torch.tensor([[0, 0], [1, 1], [2, 0]])  # Last one unspecified
>>> results, masks = batch_sparse_index(sparse, batch_indices)
>>> results
tensor([[1., 2.],
        [3., 4.],
        [0., 0.]])  # Zeros for unspecified index
>>> masks
tensor([ True,  True, False])
>>> # 2D batch of indices
>>> batch_indices = torch.tensor([[[0, 0], [2, 2]],
...                               [[1, 1], [0, 1]]])
>>> results, masks = batch_sparse_index(sparse, batch_indices)
>>> results.shape
torch.Size([2, 2, 2])
>>> masks
tensor([[ True,  True],
        [ True, False]])
>>> # Check all specified - will raise error if any index is missing
>>> try:
...     results, masks = batch_sparse_index(sparse, torch.tensor([[0, 2]]),
...                                         check_all_specified=True)
... except ValueError as e:
...     print("Error:", e)
Error: `check_all_specified` was set to True but not all gathered values were specified
Source code in pytorch_sparse_utils/indexing/basics.py
@torch.jit.script
def batch_sparse_index(
    sparse_tensor: Tensor, index_tensor: Tensor, check_all_specified: bool = False
) -> tuple[Tensor, Tensor]:
    """Batch selection of elements from a torch sparse tensor. The index tensor may
    have arbitrary batch dimensions.

    If dense_tensor = sparse_tensor.to_dense(), the equivalent dense indexing
    operation would be dense_tensor[index_tensor.unbind(-1)].

    Args:
        sparse_tensor (Tensor): Sparse tensor of dimension
            [S0, S1, ..., Sn-1, D0, D1, ..., Dm-1], with n sparse dimensions and
            m dense dimensions.
        index_tensor (Tensor): Long tensor of dimension [B0, B1, ..., Bp-1, n]
            with optional p leading batch dimensions and final dimension corresponding
            to the sparse dimensions of the sparse tensor. Negative indices are not
            supported and will be considered unspecified.
        check_all_specified (bool): If True, this function will raise a
            ValueError if any of the indices in `index_tensor` are not specified
            in `sparse_tensor`. If False, selections at unspecified indices will be
            returned with padding values of 0. Defaults to False.

    Returns:
        Tensor: Tensor of dimension [B0, B1, ..., Bp-1, D0, D1, ..., Dm-1].
        Tensor: Boolean tensor of dimension [B0, B1, ..., Bp-1], where each element is
            True if the corresponding index is a specified (nonzero) element of the
            sparse tensor and False if not.

    Raises:
        ValueError: If `check_all_specified` is set to True and not all indices in
            `index_tensor` had associated values specified in `sparse_tensor`, or if
            `index_tensor` is a nested tensor (feature planned but not implemented yet)

    Examples:
        >>> # Create a 3D sparse tensor with 2 sparse dims and 1 dense dim
        >>> indices = torch.tensor([[0, 1, 2], [0, 1, 2]])
        >>> values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
        >>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3, 2))

        >>> # Single index lookup
        >>> index = torch.tensor([[1, 1]])
        >>> result, mask = batch_sparse_index(sparse, index)
        >>> result
        tensor([[3., 4.]])
        >>> mask
        tensor([True])

        >>> # Batch lookup with some unspecified indices
        >>> batch_indices = torch.tensor([[0, 0], [1, 1], [2, 0]])  # Last one unspecified
        >>> results, masks = batch_sparse_index(sparse, batch_indices)
        >>> results
        tensor([[1., 2.],
                [3., 4.],
                [0., 0.]])  # Zeros for unspecified index
        >>> masks
        tensor([ True,  True, False])

        >>> # 2D batch of indices
        >>> batch_indices = torch.tensor([[[0, 0], [2, 2]],
        ...                               [[1, 1], [0, 1]]])
        >>> results, masks = batch_sparse_index(sparse, batch_indices)
        >>> results.shape
        torch.Size([2, 2, 2])
        >>> masks
        tensor([[ True,  True],
                [ True, False]])

        >>> # Check all specified - will raise error if any index is missing
        >>> try:
        ...     results, masks = batch_sparse_index(sparse, torch.tensor([[0, 2]]),
        ...                                         check_all_specified=True)
        ... except ValueError as e:
        ...     print("Error:", e)
        Error: `check_all_specified` was set to True but not all gathered values were specified
    """
    if index_tensor.is_nested:
        raise ValueError("Nested index tensor not supported yet")

    sparse_tensor = sparse_tensor.coalesce()
    sparse_tensor_values = sparse_tensor.values()
    dense_dim = sparse_tensor.dense_dim()

    index_search, is_specified_mask = get_sparse_index_mapping(
        sparse_tensor, index_tensor
    )
    if check_all_specified and not is_specified_mask.all():
        raise ValueError(
            "`check_all_specified` was set to True but not all gathered values "
            "were specified"
        )

    selected: Tensor = gather_mask_and_fill(
        sparse_tensor_values, index_search, is_specified_mask
    )

    out_shape = index_tensor.shape[:-1]
    assert is_specified_mask.shape == out_shape
    if dense_dim > 0:
        out_shape = out_shape + (sparse_tensor.shape[-dense_dim:])

    assert selected.shape == out_shape

    return selected, is_specified_mask

Sparse Tensor Scatter

scatter_to_sparse_tensor(sparse_tensor, index_tensor, values, check_all_specified=False)

Batch updating of elements in a torch sparse tensor. Should be equivalent to sparse_tensor[index_tensor] = values. It works by flattening the sparse tensor's sparse dims and the index tensor to 1D (and converting n-d indices to raveled indices), then using index_copy along the flattened sparse tensor.
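
The dense-equivalent semantics can be seen on a small example (an illustrative sketch with arbitrary values):

>>> sparse = torch.sparse_coo_tensor(
...     torch.tensor([[0], [0]]), torch.tensor([5.0]), (2, 2)
... )
>>> idx = torch.tensor([[0, 0], [1, 1]])
>>> vals = torch.tensor([7.0, 9.0])
>>> scatter_to_sparse_tensor(sparse, idx, vals).to_dense()
tensor([[7., 0.],
        [0., 9.]])
>>> # Same result as assigning vals into sparse.to_dense() at idx.unbind(-1)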

Parameters:
  • sparse_tensor (Tensor) –

    Sparse tensor of dimension [s0, s1, s2, ..., d0, d1, d2, ...]; where s0, s1, ... are S leading sparse dimensions and d0, d1, d2, ... are D dense dimensions.

  • index_tensor (Tensor) –

    Long tensor of dimension [b0, b1, b2, ..., S]; where b0, b1, b2, ... are B leading batch dimensions.

  • values (Tensor) –

    Tensor of dimension [b0, b1, b2, ... d0, d1, d2, ...], where dimensions are as above.

  • check_all_specified (bool, default: False ) –

    If True, this function will throw a ValueError if any of the indices specified in index_tensor are not already present in sparse_tensor. Default: False.

Returns:
  • Tensor( Tensor ) –

    sparse_tensor with the new values scattered into it

Notes

This function uses index_copy as the underlying mechanism to write new values, so duplicate indices in index_tensor will have the same result as other uses of index_copy, i.e., the result will depend on which copy occurs last. This imitates the behavior of scatter-like operations rather than the typical coalescing deduplication behavior of sparse tensors.
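
As an illustrative sketch of this behavior (the setup here is arbitrary), one of the duplicate values wins rather than the two being summed:

>>> base = torch.sparse_coo_tensor(
...     torch.tensor([[0], [0]]), torch.tensor([1.0]), (2, 2)
... )
>>> dup_idx = torch.tensor([[0, 0], [0, 0]])  # duplicate index
>>> dup_vals = torch.tensor([5.0, 6.0])
>>> out = scatter_to_sparse_tensor(base, dup_idx, dup_vals)
>>> float(out.to_dense()[0, 0]) in (5.0, 6.0)  # one copy wins; not 11.0
True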

Examples:

>>> # Create a sparse tensor with values
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 0]])
>>> values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3, 2))
>>> # Update existing values
>>> update_indices = torch.tensor([[0, 0], [1, 1]])
>>> new_values = torch.tensor([[10.0, 20.0], [30.0, 40.0]])
>>> updated = scatter_to_sparse_tensor(sparse, update_indices, new_values)
>>> updated.to_dense()
tensor([[[10., 20.],  # Updated
         [ 0.,  0.],
         [ 0.,  0.]],
        [[ 0.,  0.],
         [30., 40.],  # Updated
         [ 0.,  0.]],
        [[5.,  6.],   # Unchanged
         [ 0.,  0.],
         [ 0.,  0.]]])
>>> # Add new values (scatter to unspecified locations)
>>> new_indices = torch.tensor([[0, 2], [2, 2]])
>>> new_values = torch.tensor([[100.0, 200.0], [300.0, 400.0]])
>>> updated = scatter_to_sparse_tensor(sparse, new_indices, new_values)
>>> updated.to_dense()[0, 2]  # New value added
tensor([100., 200.])
>>> # Batch update with multiple indices
>>> batch_indices = torch.tensor([[[0, 0], [1, 1]],
...                               [[2, 0], [0, 1]]])
>>> batch_values = torch.tensor([[[11., 12.], [13., 14.]],
...                              [[15., 16.], [17., 18.]]])
>>> # Flatten batch dimensions
>>> flat_indices = batch_indices.reshape(-1, 2)
>>> flat_values = batch_values.reshape(-1, 2)
>>> updated = scatter_to_sparse_tensor(sparse, flat_indices, flat_values)
>>> # check_all_specified example
>>> indices = torch.tensor([[0, 0], [1, 1]])
>>> values = torch.tensor([1.0, 2.0])
>>> sparse = torch.sparse_coo_tensor(indices.T, values, (2, 2))
>>>
>>> # This will succeed (all indices exist)
>>> update_indices = torch.tensor([[0, 0]])
>>> update_values = torch.tensor([10.0])
>>> result = scatter_to_sparse_tensor(sparse, update_indices, update_values,
...                                   check_all_specified=True)
>>>
>>> # This will raise ValueError (index [1, 0] doesn't exist)
>>> try:
...     bad_indices = torch.tensor([[1, 0]])
...     bad_values = torch.tensor([20.0])
...     result = scatter_to_sparse_tensor(sparse, bad_indices, bad_values,
...                                       check_all_specified=True)
... except ValueError as e:
...     print("Error:", e)
Error: `check_all_specified` was set to True but not all gathered values were specified
Source code in pytorch_sparse_utils/indexing/scatter.py
def scatter_to_sparse_tensor(
    sparse_tensor: Tensor,
    index_tensor: Tensor,
    values: Tensor,
    check_all_specified: bool = False,
) -> Tensor:
    """Batch updating of elements in a torch sparse tensor. Should be
    equivalent to sparse_tensor[index_tensor] = values. It works by flattening
    the sparse tensor's sparse dims and the index tensor to 1D (and converting
    n-d indices to raveled indices), then using index_copy along the flattened
    sparse tensor.

    Args:
        sparse_tensor (Tensor): Sparse tensor of dimension
            [s0, s1, s2, ..., d0, d1, d2, ...]; where s0, s1, ... are
            S leading sparse dimensions and d0, d1, d2, ... are D dense dimensions.
        index_tensor (Tensor): Long tensor of dimension [b0, b1, b2, ..., S]; where
            b0, b1, b2, ... are B leading batch dimensions.
        values (Tensor): Tensor of dimension [b0, b1, b2, ... d0, d1, d2, ...], where
            dimensions are as above.
        check_all_specified (bool): If True, this function will throw a ValueError
            if any of the indices specified in index_tensor are not already present
            in sparse_tensor. Default: False.

    Returns:
        Tensor: sparse_tensor with the new values scattered into it

    Notes:
        This function uses index_copy as the underlying mechanism to write new values,
            so duplicate indices in index_tensor will have the same result as other
            uses of index_copy, i.e., the result will depend on which copy occurs last.
            This imitates the behavior of scatter-like operations rather than the
            typical coalescing deduplication behavior of sparse tensors.

    Examples:
        >>> # Create a sparse tensor with values
        >>> indices = torch.tensor([[0, 1, 2], [0, 1, 0]])
        >>> values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
        >>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3, 2))

        >>> # Update existing values
        >>> update_indices = torch.tensor([[0, 0], [1, 1]])
        >>> new_values = torch.tensor([[10.0, 20.0], [30.0, 40.0]])
        >>> updated = scatter_to_sparse_tensor(sparse, update_indices, new_values)
        >>> updated.to_dense()
        tensor([[[10., 20.],  # Updated
                 [ 0.,  0.],
                 [ 0.,  0.]],
                [[ 0.,  0.],
                 [30., 40.],  # Updated
                 [ 0.,  0.]],
                [[5.,  6.],   # Unchanged
                 [ 0.,  0.],
                 [ 0.,  0.]]])

        >>> # Add new values (scatter to unspecified locations)
        >>> new_indices = torch.tensor([[0, 2], [2, 2]])
        >>> new_values = torch.tensor([[100.0, 200.0], [300.0, 400.0]])
        >>> updated = scatter_to_sparse_tensor(sparse, new_indices, new_values)
        >>> updated.to_dense()[0, 2]  # New value added
        tensor([100., 200.])

        >>> # Batch update with multiple indices
        >>> batch_indices = torch.tensor([[[0, 0], [1, 1]],
        ...                               [[2, 0], [0, 1]]])
        >>> batch_values = torch.tensor([[[11., 12.], [13., 14.]],
        ...                              [[15., 16.], [17., 18.]]])
        >>> # Flatten batch dimensions
        >>> flat_indices = batch_indices.reshape(-1, 2)
        >>> flat_values = batch_values.reshape(-1, 2)
        >>> updated = scatter_to_sparse_tensor(sparse, flat_indices, flat_values)

        >>> # check_all_specified example
        >>> indices = torch.tensor([[0, 0], [1, 1]])
        >>> values = torch.tensor([1.0, 2.0])
        >>> sparse = torch.sparse_coo_tensor(indices.T, values, (2, 2))
        >>>
        >>> # This will succeed (all indices exist)
        >>> update_indices = torch.tensor([[0, 0]])
        >>> update_values = torch.tensor([10.0])
        >>> result = scatter_to_sparse_tensor(sparse, update_indices, update_values,
        ...                                   check_all_specified=True)
        >>>
        >>> # This will raise ValueError (index [1, 0] doesn't exist)
        >>> try:
        ...     bad_indices = torch.tensor([[1, 0]])
        ...     bad_values = torch.tensor([20.0])
        ...     result = scatter_to_sparse_tensor(sparse, bad_indices, bad_values,
        ...                                       check_all_specified=True)
        ... except ValueError as e:
        ...     print("Error:", e)
        Error: `check_all_specified` was set to True but not all gathered values were specified
    """
    if index_tensor.is_nested:
        assert values.is_nested
        index_tensor = torch.cat(index_tensor.unbind())
        values = torch.cat(values.unbind())

    dense_dim = sparse_tensor.dense_dim()
    sparse_dim = sparse_tensor.sparse_dim()
    values_batch_dims = values.shape[:-dense_dim] if dense_dim else values.shape
    if index_tensor.shape[:-1] != values_batch_dims:
        raise ValueError(
            "Expected matching batch dims for `index_tensor` and `values`, but got "
            f"batch dims {index_tensor.shape[:-1]} and "
            f"{values_batch_dims}, respectively."
        )

    sparse_tensor = sparse_tensor.coalesce()
    sparse_tensor_values = sparse_tensor.values()
    index_search, is_specified_mask = get_sparse_index_mapping(
        sparse_tensor, index_tensor, sanitize_linear_index_tensor=False
    )

    all_specified = torch.all(is_specified_mask)

    if check_all_specified and not all_specified:
        raise ValueError(
            "`check_all_specified` was set to True but not all gathered values "
            "were specified"
        )

    # In-place update of existing values
    if not sparse_tensor_values.requires_grad:
        updated_values = sparse_tensor_values.index_copy_(
            0, index_search[is_specified_mask], values[is_specified_mask]
        )
    else:
        updated_values = sparse_tensor_values.index_copy(
            0, index_search[is_specified_mask], values[is_specified_mask]
        )

    if all_specified:  # No new values to append: tensor is fully updated
        return torch.sparse_coo_tensor(
            sparse_tensor.indices(),
            updated_values,
            sparse_tensor.shape,
            dtype=sparse_tensor.dtype,
            device=sparse_tensor.device,
            is_coalesced=True,
        )

    # Need to append at least one new value: pre-sort the index tensor to save
    # the final coalesce operation

    # Pull out new values and indices to be added
    new_insert_pos: Tensor = index_search[~is_specified_mask]
    new_nd_indices = index_tensor[~is_specified_mask]
    new_values = values[~is_specified_mask]

    # Get sparse shape info for linearization
    sparse_sizes = torch.tensor(
        sparse_tensor.shape[:sparse_dim], device=sparse_tensor.device
    )

    if (new_nd_indices >= sparse_sizes.unsqueeze(0)).any():
        raise ValueError(
            "`index_tensor` has indices that are out of bounds of the original "
            f"sparse tensor's sparse shape ({sparse_sizes})."
        )

    # Obtain linearized versions of all indices for sorting
    old_indices_nd = sparse_tensor.indices()
    linear_offsets = _make_linear_offsets(sparse_sizes)
    new_indices_lin: Tensor = (new_nd_indices * linear_offsets).sum(-1)

    # Find duplicate linear indices
    unique_new_indices_lin, inverse = torch.unique(
        new_indices_lin, sorted=True, return_inverse=True
    )

    # Use inverse of indices unique to write to new values tensor and tensor of
    # insertion positions
    deduped_new_values = new_values.new_empty(
        (unique_new_indices_lin.size(0),) + new_values.shape[1:]
    )
    deduped_insert_pos = new_insert_pos.new_empty(unique_new_indices_lin.size(0))
    # Deciding which duplicate value wins is offloaded to index_copy
    deduped_new_values.index_copy_(0, inverse, new_values)
    deduped_insert_pos.index_copy_(0, inverse, new_insert_pos)

    # Convert uniqueified flattened indices to n-D for inclusion in indices tensor
    unique_new_indices_nd = unflatten_nd_indices(
        unique_new_indices_lin.unsqueeze(0), sparse_sizes, linear_offsets
    )

    # Concatenate old and new indices/values and sort
    combined_indices, combined_values = _merge_sorted(
        old_indices_nd,
        unique_new_indices_nd,
        updated_values,
        deduped_new_values,
        deduped_insert_pos,
    )

    return torch.sparse_coo_tensor(
        combined_indices,
        combined_values,
        sparse_tensor.shape,
        dtype=sparse_tensor.dtype,
        device=sparse_tensor.device,
        is_coalesced=True,
    )

Miscellaneous Functions

unique_rows(tensor, sorted=True)

Returns the indices of the unique rows (first dimension) of a 2D integer tensor.

Parameters:
  • tensor (Tensor) –

    A 2D tensor of integer type.

  • sorted (bool, default: True ) –

    Whether to sort the indices of unique rows before returning. If False, returned indices will be in lexicographic order of the rows.

Returns:
  • Tensor( Tensor ) –

    A 1D tensor whose elements are the indices of the unique rows of the input tensor, i.e., if the return tensor is inds, then tensor[inds] gives a 2D tensor of all unique rows of the input tensor.

Raises:
  • OverflowError

    If the tensor has values that are large enough to cause overflow errors when hashing each row to a single value.

Examples:

>>> tensor = torch.tensor([[1, 2, 3],
...                        [4, 5, 6],
...                        [1, 2, 3],  # Duplicate of row 0
...                        [7, 8, 9],
...                        [4, 5, 6]])  # Duplicate of row 1
>>> unique_indices = unique_rows(tensor)
>>> unique_indices
tensor([0, 1, 3])
>>> tensor[unique_indices]  # All unique rows
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
Source code in pytorch_sparse_utils/indexing/unique.py
@torch.jit.script
def unique_rows(tensor: Tensor, sorted: bool = True) -> Tensor:
    """Returns the indices of the unique rows (first dimension) of a 2D integer tensor.

    Args:
        tensor (Tensor): A 2D tensor of integer type.
        sorted (bool): Whether to sort the indices of unique rows before returning.
            If False, returned indices will be in lexicographic order of the rows.

    Returns:
        Tensor: A 1D tensor whose elements are the indices of the unique rows of
            the input tensor, i.e., if the return tensor is `inds`, then
            tensor[inds] gives a 2D tensor of all unique rows of the input tensor.

    Raises:
        OverflowError: If the tensor has values that are large enough to cause overflow
            errors when hashing each row to a single value.

    Examples:
        >>> tensor = torch.tensor([[1, 2, 3],
        ...                        [4, 5, 6],
        ...                        [1, 2, 3],  # Duplicate of row 0
        ...                        [7, 8, 9],
        ...                        [4, 5, 6]])  # Duplicate of row 1
        >>> unique_indices = unique_rows(tensor)
        >>> unique_indices
        tensor([0, 1, 3])
        >>> tensor[unique_indices]  # All unique rows
        tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
    """
    if tensor.ndim != 2:
        raise ValueError(f"Expected a 2D tensor, got ndim={tensor.ndim}")
    if torch.is_floating_point(tensor) or torch.is_complex(tensor):
        raise ValueError(f"Expected integer tensor, got dtype={tensor.dtype}")

    max_vals = tensor.max(0).values
    min_vals = tensor.min(0).values

    # Check for overflow problems
    INT64_MAX = 9223372036854775807
    if (max_vals >= INT64_MAX).any():
        raise OverflowError(
            f"Tensor contains values at or near maximum int64 value ({INT64_MAX}), "
            "which would lead to overflow errors when computing unique rows."
        )

    log_sum = (max_vals + 1).log().sum()
    log_max = torch.tensor(INT64_MAX, device=max_vals.device).log()

    if log_sum > log_max:
        raise OverflowError(
            "Hashing rows would cause integer overflow. Maximum hashed row product is "
            f"approx {log_sum.exp()} compared to max int64 value of {INT64_MAX}."
        )

    # Handle negative values by shifting to nonnegative
    has_negs = min_vals < 0
    if has_negs.any():
        # Shift each column to be nonnegative
        neg_shift = torch.where(has_negs, min_vals, min_vals.new_zeros([]))
        tensor = tensor - neg_shift
        max_vals = max_vals - neg_shift

    tensor_flat, _ = flatten_nd_indices(tensor.T.long(), max_vals)
    tensor_flat: Tensor = tensor_flat.squeeze(0)

    unique_flat_indices, unique_inverse = torch.unique(tensor_flat, return_inverse=True)
    unique_row_indices: Tensor = unique_inverse.new_full(
        (unique_flat_indices.size(0),), tensor_flat.size(0)
    )
    unique_row_indices.scatter_reduce_(
        0,
        unique_inverse,
        torch.arange(tensor_flat.size(0), device=tensor.device),
        "amin",
    )
    if sorted:
        unique_row_indices = unique_row_indices.sort().values
    return unique_row_indices

union_sparse_indices(sparse_tensor_1, sparse_tensor_2)

Creates unified sparse tensors with the union of indices from both input tensors.

This function takes two sparse tensors and returns versions of them that share the same set of indices (the union of indices from both inputs). For indices present in only one of the tensors, zeros are filled in for the corresponding values in the other tensor.

This function is useful for ensuring a one-to-one correspondence between two sparse tensors' respective values() tensors, which in turn may be useful for elementwise value comparisons like loss functions.

Parameters:
  • sparse_tensor_1 (Tensor) –

    First sparse tensor.

  • sparse_tensor_2 (Tensor) –

    Second sparse tensor with the same sparse and dense dimensions as sparse_tensor_1.

Returns:
  • tuple[Tensor, Tensor]

    A tuple containing:

      - tensor_1_unioned (Tensor): First tensor with indices expanded to include all indices from the second tensor (with zeros for missing values).
      - tensor_2_unioned (Tensor): Second tensor with indices expanded to include all indices from the first tensor (with zeros for missing values).

Raises:
  • ValueError

    If either input is not a sparse tensor or if the sparse and dense dimensions don't match between tensors.

Note

For very large sparse tensors, this operation may require significant memory for intermediate tensors.
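
Examples:

An illustrative sketch with two small 1-D sparse tensors:

>>> a = torch.sparse_coo_tensor(torch.tensor([[0, 1]]), torch.tensor([1.0, 2.0]), (3,))
>>> b = torch.sparse_coo_tensor(torch.tensor([[1, 2]]), torch.tensor([3.0, 4.0]), (3,))
>>> a_u, b_u = union_sparse_indices(a, b)
>>> a_u.indices()
tensor([[0, 1, 2]])
>>> a_u.values()
tensor([1., 2., 0.])
>>> b_u.values()
tensor([0., 3., 4.])
>>> # The values() tensors now correspond elementwise, e.g. for an MSE loss
>>> torch.nn.functional.mse_loss(a_u.values(), b_u.values())
tensor(6.)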

Source code in pytorch_sparse_utils/indexing/misc.py
@torch.jit.script
def union_sparse_indices(
    sparse_tensor_1: Tensor, sparse_tensor_2: Tensor
) -> tuple[Tensor, Tensor]:
    """Creates unified sparse tensors with the union of indices from both input tensors.

    This function takes two sparse tensors and returns versions of them that share the
    same set of indices (the union of indices from both inputs). For indices present in
    only one of the tensors, zeros are filled in for the corresponding values in the
    other tensor.

    This function is useful for ensuring a one-to-one correspondence between two
    sparse tensors' respective values() tensors, which in turn may be useful for
    elementwise value comparisons like loss functions.

    Args:
        sparse_tensor_1 (Tensor): First sparse tensor.
        sparse_tensor_2 (Tensor): Second sparse tensor with the same sparse and dense
            dimensions as sparse_tensor_1.

    Returns:
        tuple[Tensor, Tensor]: A tuple containing:
            - tensor_1_unioned (Tensor): First tensor with indices expanded to include
                all indices from the second tensor (with zeros for missing values).
            - tensor_2_unioned (Tensor): Second tensor with indices expanded to include
                all indices from the first tensor (with zeros for missing values).

    Raises:
        ValueError: If either input is not a sparse tensor or if the sparse and
            dense dimensions don't match between tensors.

    Note:
        For very large sparse tensors, this operation may require significant memory
        for intermediate tensors.
    """
    if not sparse_tensor_1.is_sparse or not sparse_tensor_2.is_sparse:
        raise ValueError(
            "Expected two sparse tensors; got "
            f"{__sparse_or_dense(sparse_tensor_1)} and {__sparse_or_dense(sparse_tensor_2)}"
        )
    if sparse_tensor_1.shape != sparse_tensor_2.shape:
        raise ValueError(
            "Expected tensors to have same shapes; got "
            f"{sparse_tensor_1.shape} and {sparse_tensor_2.shape}"
        )
    if sparse_tensor_1.sparse_dim() != sparse_tensor_2.sparse_dim():
        raise ValueError(
            "Expected both sparse tensors to have equal numbers of sparse dims; got "
            f"{sparse_tensor_1.sparse_dim()} and {sparse_tensor_2.sparse_dim()}"
        )
    if sparse_tensor_1.dense_dim() != sparse_tensor_2.dense_dim():
        raise ValueError(
            "Expected both sparse tensors to have equal numbers of dense dims; got "
            f"{sparse_tensor_1.dense_dim()} and {sparse_tensor_2.dense_dim()}"
        )

    M = sparse_tensor_1.sparse_dim()
    K = sparse_tensor_1.dense_dim()

    sparse_tensor_1 = sparse_tensor_1.coalesce()
    sparse_tensor_2 = sparse_tensor_2.coalesce()

    indices_1, values_1 = sparse_tensor_1.indices(), sparse_tensor_1.values()
    indices_2, values_2 = sparse_tensor_2.indices(), sparse_tensor_2.values()

    # Need to find all indices that are unique to each sparse tensor
    # To do this, stack one of them twice and the other once
    indices_2_2_1 = torch.cat([indices_2, indices_2, indices_1], -1)
    uniques, counts = torch.unique(indices_2_2_1, dim=-1, return_counts=True)
    # Any that appear twice in the stacked indices are unique to tensor 2
    # and any that appear once are unique to tensor 1
    # (indices that appear 3x are shared already)
    indices_only_in_tensor_1 = uniques[:, counts == 1]
    indices_only_in_tensor_2 = uniques[:, counts == 2]

    # Figure out how many new indices will be added to each sparse tensor
    n_exclusives_1 = indices_only_in_tensor_1.size(-1)
    n_exclusives_2 = indices_only_in_tensor_2.size(-1)

    # Make zero-padding for new values tensors
    pad_zeros_1 = values_1.new_zeros(
        (n_exclusives_2,) + sparse_tensor_1.shape[M : M + K]
    )
    pad_zeros_2 = values_2.new_zeros(
        (n_exclusives_1,) + sparse_tensor_1.shape[M : M + K]
    )

    # Make the new tensors by stacking indices and values together
    tensor_1_unioned = torch.sparse_coo_tensor(
        torch.cat([indices_1, indices_only_in_tensor_2], -1),
        torch.cat([values_1, pad_zeros_1], 0),
        size=sparse_tensor_1.shape,
        device=sparse_tensor_1.device,
    ).coalesce()

    tensor_2_unioned = torch.sparse_coo_tensor(
        torch.cat([indices_2, indices_only_in_tensor_1], -1),
        torch.cat([values_2, pad_zeros_2], 0),
        size=sparse_tensor_2.shape,
        device=sparse_tensor_2.device,
    ).coalesce()

    if not torch.equal(tensor_1_unioned.indices(), tensor_2_unioned.indices()):
        raise RuntimeError("Internal error: unioned tensors have different indices")

    return tensor_1_unioned, tensor_2_unioned

Indexing Helpers

flatten_nd_indices(indices, sizes)

Flattens N-dimensional indices into 1-dimensional scalar indices.

Similar to np.ravel_multi_index but with a slightly different API.

Parameters:
  • indices (Tensor) –

    Integer coordinate tensor of shape [N, B], where N is the number of dimensions to be flattened and B is the batch dimension.

  • sizes (Tensor) –

    Extents of every dimension, of shape [N]

Returns:
  • flat_indices( Tensor ) –

    Flattened indices tensor, of shape [1, B]

  • offsets( Tensor ) –

    Strides that were used for flattening (needed for unflatten), of shape [N]

Examples:

>>> # 2D indices to 1D
>>> indices = torch.tensor([[0, 1, 2],  # row indices
...                         [0, 2, 1]]) # col indices
>>> sizes = torch.tensor([3, 3])  # 3x3 grid
>>> flat, offsets = flatten_nd_indices(indices, sizes)
>>> flat
tensor([[0, 5, 7]])  # 0*3+0=0, 1*3+2=5, 2*3+1=7
>>> offsets
tensor([3, 1])
>>> # 3D indices to 1D
>>> indices = torch.tensor([[0, 1],    # dim 0
...                         [1, 0],    # dim 1
...                         [2, 1]])   # dim 2
>>> sizes = torch.tensor([2, 2, 3])  # 2x2x3 tensor
>>> flat, offsets = flatten_nd_indices(indices, sizes)
>>> flat
tensor([[5, 7]])  # 0*6+1*3+2=5, 1*6+0*3+1=7
>>> offsets
tensor([6, 3, 1])
Source code in pytorch_sparse_utils/indexing/utils.py
@torch.jit.script
def flatten_nd_indices(indices: Tensor, sizes: Tensor) -> tuple[Tensor, Tensor]:
    """Flattens N-dimensional indices into 1-dimensional scalar indices.

    Similar to np.ravel_multi_index but with a slightly different API.

    Args:
        indices (Tensor): Integer coordinate tensor of shape [N, B], where N
            is the number of dimensions to be flattened and B is the batch dimension.
        sizes (Tensor): Extents of every dimension, of shape [N]

    Returns:
        flat_indices (Tensor): Flattened indices tensor, of shape [1, B]
        offsets (Tensor): Strides that were used for flattening (needed for unflatten),
            of shape [N]

    Examples:
        >>> # 2D indices to 1D
        >>> indices = torch.tensor([[0, 1, 2],  # row indices
        ...                         [0, 2, 1]]) # col indices
        >>> sizes = torch.tensor([3, 3])  # 3x3 grid
        >>> flat, offsets = flatten_nd_indices(indices, sizes)
        >>> flat
        tensor([[0, 5, 7]])  # 0*3+0=0, 1*3+2=5, 2*3+1=7
        >>> offsets
        tensor([3, 1])

        >>> # 3D indices to 1D
        >>> indices = torch.tensor([[0, 1],    # dim 0
        ...                         [1, 0],    # dim 1
        ...                         [2, 1]])   # dim 2
        >>> sizes = torch.tensor([2, 2, 3])  # 2x2x3 tensor
        >>> flat, offsets = flatten_nd_indices(indices, sizes)
        >>> flat
        tensor([[5, 7]])  # 0*6+1*3+2=5, 1*6+0*3+1=7
        >>> offsets
        tensor([6, 3, 1])
    """
    offsets = _make_linear_offsets(sizes)  # [N]
    flat_indices = (indices * offsets.unsqueeze(-1)).sum(0, keepdim=True)
    return flat_indices, offsets

unflatten_nd_indices(flat_indices, dim_sizes, offsets=None)

Reconstructs ('unflattens') N-D indices from 1D 'flattened' indices.

Parameters:
  • flat_indices (Tensor) –

    Flat indices tensor of shape [1, B]

  • dim_sizes (Tensor) –

    Original sizes of every dimension, of shape [N]

  • offsets (Optional[Tensor], default: None ) –

    Offsets that were used for flattening, as returned by _make_linear_offsets or flatten_nd_indices. If None, it will be recalculated from dim_sizes

Returns:
  • Tensor( Tensor ) –

    N-D indices tensor of shape [N, B]

Examples:

>>> # Unflatten 1D to 2D indices (inverse of flatten)
>>> flat_indices = torch.tensor([[0, 5, 7]])
>>> dim_sizes = torch.tensor([3, 3])  # 3x3 grid
>>> nd_indices = unflatten_nd_indices(flat_indices, dim_sizes)
>>> nd_indices
tensor([[0, 1, 2],  # row indices
        [0, 2, 1]]) # col indices
>>> # Unflatten to 3D indices
>>> flat_indices = torch.tensor([[5, 7, 23]])
>>> dim_sizes = torch.tensor([2, 3, 4])  # 2x3x4 tensor
>>> nd_indices = unflatten_nd_indices(flat_indices, dim_sizes)
>>> nd_indices
tensor([[0, 0, 1],  # dim 0
        [1, 1, 2],  # dim 1
        [1, 3, 3]]) # dim 2
>>> # Using precomputed offsets (more efficient for repeated calls)
>>> dim_sizes = torch.tensor([4, 5, 6])
>>> offsets = _make_linear_offsets(dim_sizes)  # [30, 6, 1]
>>> flat_indices = torch.tensor([[0, 31, 119]])
>>> nd_indices = unflatten_nd_indices(flat_indices, dim_sizes, offsets)
>>> nd_indices
tensor([[0, 1, 3],
        [0, 0, 4],
        [0, 1, 5]])
Source code in pytorch_sparse_utils/indexing/utils.py
@torch.jit.script
def unflatten_nd_indices(
    flat_indices: Tensor, dim_sizes: Tensor, offsets: Optional[Tensor] = None
) -> Tensor:
    """Reconstructs ('unflattens') N-D indices from 1D 'flattened' indices.

    Args:
        flat_indices (Tensor): Flat indices tensor of shape [1, B]
        dim_sizes (Tensor): Original sizes of every dimension, of shape [N]
        offsets (Optional[Tensor]): Offsets that were used for flattening, as returned
            by _make_linear_offsets or flatten_nd_indices. If None, it will be
            recalculated from `dim_sizes`

    Returns:
        Tensor: N-D indices tensor of shape [N, B]

    Examples:
        >>> # Unflatten 1D to 2D indices (inverse of flatten)
        >>> flat_indices = torch.tensor([[0, 5, 7]])
        >>> dim_sizes = torch.tensor([3, 3])  # 3x3 grid
        >>> nd_indices = unflatten_nd_indices(flat_indices, dim_sizes)
        >>> nd_indices
        tensor([[0, 1, 2],  # row indices
                [0, 2, 1]]) # col indices

        >>> # Unflatten to 3D indices
        >>> flat_indices = torch.tensor([[5, 7, 23]])
        >>> dim_sizes = torch.tensor([2, 3, 4])  # 2x3x4 tensor
        >>> nd_indices = unflatten_nd_indices(flat_indices, dim_sizes)
        >>> nd_indices
        tensor([[0, 0, 1],  # dim 0
                [1, 1, 2],  # dim 1
                [1, 3, 3]]) # dim 2

        >>> # Using precomputed offsets (more efficient for repeated calls)
        >>> dim_sizes = torch.tensor([4, 5, 6])
        >>> offsets = _make_linear_offsets(dim_sizes)  # [30, 6, 1]
        >>> flat_indices = torch.tensor([[0, 31, 119]])
        >>> nd_indices = unflatten_nd_indices(flat_indices, dim_sizes, offsets)
        >>> nd_indices
        tensor([[0, 1, 3],
                [0, 0, 4],
                [0, 1, 5]])
    """
    if offsets is None:
        offsets = _make_linear_offsets(dim_sizes)
    assert offsets is not None
    N = dim_sizes.numel()
    B = flat_indices.size(-1)
    out = torch.empty(N, B, device=flat_indices.device, dtype=torch.long)

    # integer divide by stride
    torch.div(
        flat_indices.expand_as(out),
        offsets.unsqueeze(1),
        rounding_mode="floor",
        out=out,
    )

    # modulus by sizes
    torch.remainder(out, dim_sizes.unsqueeze(1), out=out)

    return out

flatten_sparse_indices(tensor, start_axis, end_axis)

Flattens a sparse tensor's indices along specified dimensions.

This function takes a sparse tensor and flattens its indices along a contiguous range of dimensions. It returns the new indices, the corresponding new shape, and the linear offsets used in the flattening process.

Parameters:
  • tensor (Tensor) –

    The input tensor. Its indices are expected to be in COO format.

  • start_axis (int) –

    Starting axis (inclusive) of the dimensions to flatten.

  • end_axis (int) –

    Ending axis (inclusive) of the dimensions to flatten.

Returns:
  • new_indices( Tensor ) –

    The flattened indices of shape (D, N), where D is the number of dimensions in the flattened tensor and N is the number of nonzero elements.

  • new_shape( Tensor ) –

    The new shape of the flattened tensor, of shape (D,).

  • dim_linear_offsets( Tensor ) –

    The linear offsets used during flattening, of shape (K,), where K is the number of flattened dimensions.

Examples:

>>> # Create a 3D sparse tensor
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 0], [1, 2, 0]])
>>> values = torch.tensor([1.0, 2.0, 3.0])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3, 3))
>>> # Flatten dimensions 0 and 1 (first two dimensions)
>>> new_indices, new_shape, offsets = flatten_sparse_indices(sparse, 0, 1)
>>> new_indices
tensor([[0, 4, 6],  # Flattened indices for dims 0-1
        [1, 2, 0]]) # Unchanged indices for dim 2
>>> new_shape
tensor([9, 3])  # 3x3 flattened to 9, last dim unchanged
>>> offsets
tensor([3, 1])
>>> # Flatten all dimensions
>>> new_indices, new_shape, offsets = flatten_sparse_indices(sparse, 0, 2)
>>> new_indices
tensor([[1, 14, 18]])  # All indices flattened to 1D
>>> new_shape
tensor([27])  # 3x3x3 = 27
>>> # Create a 4D sparse tensor and flatten middle dimensions
>>> indices = torch.tensor([[0, 1], [1, 0], [2, 1], [0, 2]])
>>> values = torch.tensor([1.0, 2.0])
>>> sparse = torch.sparse_coo_tensor(indices, values, (2, 2, 3, 3))
>>>
>>> # Flatten dimensions 1 and 2
>>> new_indices, new_shape, offsets = flatten_sparse_indices(sparse, 1, 2)
>>> new_indices
tensor([[0, 1],  # Dim 0 unchanged
        [5, 1],  # Dims 1-2 flattened: 1*3+2=5, 0*3+1=1
        [0, 2]]) # Dim 3 unchanged
>>> new_shape
tensor([2, 6, 3])  # Shape is now [2, 2*3, 3]
Source code in pytorch_sparse_utils/indexing/utils.py
@torch.jit.script
def flatten_sparse_indices(
    tensor: Tensor, start_axis: int, end_axis: int
) -> tuple[Tensor, Tensor, Tensor]:
    """Flattens a sparse tensor's indices along specified dimensions.

    This function takes a sparse tensor and flattens its indices along a
    contiguous range of dimensions. It returns the new indices, the
    corresponding new shape, and the linear offsets used in the flattening
    process.

    Args:
        tensor (Tensor): The input tensor. Its indices are expected to be in COO format.
        start_axis (int): Starting axis (inclusive) of the dimensions to flatten.
        end_axis (int): Ending axis (inclusive) of the dimensions to flatten.

    Returns:
        - new_indices (Tensor): The flattened indices of shape (D, N), where D is
            the number of dimensions in the flattened tensor and N is the number
            of nonzero elements.
        - new_shape (Tensor): The new shape of the flattened tensor of shape (D,)
        - dim_linear_offsets (Tensor): The linear offsets used during flattening,
            of shape (K,), where K is the number of flattened dimensions.

    Examples:
        >>> # Create a 3D sparse tensor
        >>> indices = torch.tensor([[0, 1, 2], [0, 1, 0], [1, 2, 0]])
        >>> values = torch.tensor([1.0, 2.0, 3.0])
        >>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3, 3))

        >>> # Flatten dimensions 0 and 1 (first two dimensions)
        >>> new_indices, new_shape, offsets = flatten_sparse_indices(sparse, 0, 1)
        >>> new_indices
        tensor([[0, 4, 6],  # Flattened indices for dims 0-1
                [1, 2, 0]]) # Unchanged indices for dim 2
        >>> new_shape
        tensor([9, 3])  # 3x3 flattened to 9, last dim unchanged
        >>> offsets
        tensor([3, 1])

        >>> # Flatten all dimensions
        >>> new_indices, new_shape, offsets = flatten_sparse_indices(sparse, 0, 2)
        >>> new_indices
        tensor([[1, 14, 18]])  # All indices flattened to 1D
        >>> new_shape
        tensor([27])  # 3x3x3 = 27

        >>> # Create a 4D sparse tensor and flatten middle dimensions
        >>> indices = torch.tensor([[0, 1], [1, 0], [2, 1], [0, 2]])
        >>> values = torch.tensor([1.0, 2.0])
        >>> sparse = torch.sparse_coo_tensor(indices, values, (2, 2, 3, 3))
        >>>
        >>> # Flatten dimensions 1 and 2
        >>> new_indices, new_shape, offsets = flatten_sparse_indices(sparse, 1, 2)
        >>> new_indices
        tensor([[0, 1],  # Dim 0 unchanged
                [5, 1],  # Dims 1-2 flattened: 1*3+2=5, 0*3+1=1
                [0, 2]]) # Dim 3 unchanged
        >>> new_shape
        tensor([2, 6, 3])  # Shape is now [2, 2*3, 3]
    """
    tensor_indices = tensor.indices()  # sparse_dim x nnz (counterintuitive)
    indices_to_flatten = tensor_indices[start_axis : end_axis + 1]

    # convert shape to tensor since we will be doing math on it.
    # it needs to be on the same device as the sparse tensor rather than
    # staying on cpu because downstream tensors will be interacting with
    # the sparse tensor's indices tensor
    shape = torch._shape_as_tensor(tensor).to(tensor.device)

    sizes_to_flatten = shape[start_axis : end_axis + 1]

    flattened_indices, dim_linear_offsets = flatten_nd_indices(
        indices_to_flatten, sizes_to_flatten
    )

    # make new shape with the flattened axes stacked together
    new_shape = torch.cat(
        [
            shape[:start_axis],
            torch.prod(sizes_to_flatten, 0, keepdim=True),
            shape[end_axis + 1 :],
        ]
    )
    # this assertion shouldn't cause a cpu sync
    assert new_shape.size(0) == tensor.ndim - (end_axis - start_axis)

    # plug the flattened indices into the existing indices
    new_indices = torch.cat(
        [tensor_indices[:start_axis], flattened_indices, tensor_indices[end_axis + 1 :]]
    )
    return new_indices, new_shape, dim_linear_offsets
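
The sketch below is an editorial illustration rather than part of the library: it checks that the returned dim_linear_offsets reproduce the flattened coordinates and that the outputs are enough to rebuild an equivalent reshaped sparse tensor. It assumes flatten_sparse_indices is importable from pytorch_sparse_utils.indexing.utils, as the source path above suggests.

import torch
from pytorch_sparse_utils.indexing.utils import flatten_sparse_indices

# 3D sparse tensor with three nonzero elements
indices = torch.tensor([[0, 1, 2], [0, 1, 0], [1, 2, 0]])
values = torch.tensor([1.0, 2.0, 3.0])
sparse = torch.sparse_coo_tensor(indices, values, (3, 3, 3)).coalesce()

new_indices, new_shape, offsets = flatten_sparse_indices(sparse, 0, 1)

# The flattened coordinate is the offset-weighted sum of the original ones:
# linear = i0 * offsets[0] + i1 * offsets[1]
manual = (sparse.indices()[:2] * offsets.unsqueeze(-1)).sum(0)
assert torch.equal(manual, new_indices[0])

# The outputs are enough to build the reshaped (flattened) equivalent tensor
flat = torch.sparse_coo_tensor(
    new_indices, sparse.values(), [int(s) for s in new_shape]
)
assert torch.equal(flat.to_dense().reshape(3, 3, 3), sparse.to_dense())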

linearize_sparse_and_index_tensors(sparse_tensor, index_tensor)

Converts the multidimensional indices of a sparse tensor, along with a tensor of indices to look up, to a shared linearized (flattened) format suitable for fast lookup.

Parameters:
  • sparse_tensor (Tensor) –

    torch.sparse_coo_tensor with indices to linearize.

  • index_tensor (Tensor) –

    Dense tensor with indices matching sparse_tensor's sparse dims. Can be of any dimension as long as the last dimension has length equal to the sparse tensor's sparse dimension.

Raises:
  • ValueError

    If the index tensor has a different last dimension than the sparse tensor's sparse dim.

Returns:
  • sparse_tensor_indices_linear( Tensor ) –

    Linearized version of sparse_tensor.indices().

  • index_tensor_linearized( Tensor ) –

    Linearized version of index_tensor with the last dimension squeezed out.

Examples:

>>> # Create a 2D sparse tensor
>>> indices = torch.tensor([[0, 1, 2], [0, 2, 1]])
>>> values = torch.tensor([1.0, 2.0, 3.0])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3))
>>> # Indices to look up
>>> lookup_indices = torch.tensor([[0, 0], [1, 2], [2, 1]])
>>>
>>> sparse_lin, lookup_lin = linearize_sparse_and_index_tensors(sparse, lookup_indices)
>>> sparse_lin
tensor([0, 5, 7])  # 0*3+0=0, 1*3+2=5, 2*3+1=7
>>> lookup_lin
tensor([0, 5, 7])  # Same linearization scheme
>>> # Batch of indices with different shape
>>> batch_indices = torch.tensor([[[0, 0], [1, 1]],
...                               [[2, 2], [0, 2]]])
>>> sparse_lin, batch_lin = linearize_sparse_and_index_tensors(sparse, batch_indices)
>>> batch_lin
tensor([0, 4, 8, 2])  # Flattened: 0, 1*3+1=4, 2*3+2=8, 0*3+2=2
>>> # 3D sparse tensor
>>> indices = torch.tensor([[0, 1], [0, 1], [1, 0]])
>>> values = torch.tensor([1.0, 2.0])
>>> sparse = torch.sparse_coo_tensor(indices, values, (2, 2, 2))
>>>
>>> lookup = torch.tensor([[0, 0, 1], [1, 1, 0]])
>>> sparse_lin, lookup_lin = linearize_sparse_and_index_tensors(sparse, lookup)
>>> sparse_lin
tensor([1, 6])  # 0*4+0*2+1=1, 1*4+1*2+0=6
>>> lookup_lin
tensor([1, 6])
Source code in pytorch_sparse_utils/indexing/utils.py
@torch.jit.script
def linearize_sparse_and_index_tensors(
    sparse_tensor: Tensor, index_tensor: Tensor
) -> tuple[Tensor, Tensor]:
    """Converts multidimensional indices of a sparse tensor and a tensor of indices
    that we want to retrieve to a shared linearized (flattened) format suitable
    for fast lookup.

    Args:
        sparse_tensor (Tensor): torch.sparse_coo_tensor with indices to linearize.
        index_tensor (Tensor): Dense tensor with indices matching sparse_tensor's
            sparse dims. Can be of any dimension as long as the last dimension
            has length equal to the sparse tensor's sparse dimension.

    Raises:
        ValueError: If the index tensor has a different last dimension than the
            sparse tensor's sparse dim.

    Returns:
        sparse_tensor_indices_linear (Tensor): Linearized version of
            sparse_tensor.indices().
        index_tensor_linearized (Tensor): Linearized version of index_tensor
            with the last dimension squeezed out.

    Examples:
        >>> # Create a 2D sparse tensor
        >>> indices = torch.tensor([[0, 1, 2], [0, 2, 1]])
        >>> values = torch.tensor([1.0, 2.0, 3.0])
        >>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3))

        >>> # Indices to look up
        >>> lookup_indices = torch.tensor([[0, 0], [1, 2], [2, 1]])
        >>>
        >>> sparse_lin, lookup_lin = linearize_sparse_and_index_tensors(sparse, lookup_indices)
        >>> sparse_lin
        tensor([0, 5, 7])  # 0*3+0=0, 1*3+2=5, 2*3+1=7
        >>> lookup_lin
        tensor([0, 5, 7])  # Same linearization scheme

        >>> # Batch of indices with different shape
        >>> batch_indices = torch.tensor([[[0, 0], [1, 1]],
        ...                               [[2, 2], [0, 2]]])
        >>> sparse_lin, batch_lin = linearize_sparse_and_index_tensors(sparse, batch_indices)
        >>> batch_lin
        tensor([0, 4, 8, 2])  # Flattened: 0, 1*3+1=4, 2*3+2=8, 0*3+2=2

        >>> # 3D sparse tensor
        >>> indices = torch.tensor([[0, 1], [0, 1], [1, 0]])
        >>> values = torch.tensor([1.0, 2.0])
        >>> sparse = torch.sparse_coo_tensor(indices, values, (2, 2, 2))
        >>>
        >>> lookup = torch.tensor([[0, 0, 1], [1, 1, 0]])
        >>> sparse_lin, lookup_lin = linearize_sparse_and_index_tensors(sparse, lookup)
        >>> sparse_lin
        tensor([1, 6])  # 0*4+0*2+1=1, 1*4+1*2+0=6
        >>> lookup_lin
        tensor([1, 6])
    """
    if index_tensor.shape[-1] != sparse_tensor.sparse_dim():
        if (
            sparse_tensor.sparse_dim() - 1 == index_tensor.shape[-1]
            and sparse_tensor.shape[-1] == 1
            and sparse_tensor.dense_dim() == 0
        ):
            # handle case where there's a length-1 trailing sparse dim and the
            # index tensor ignores it
            sparse_tensor = sparse_tensor[..., 0].coalesce()
        else:
            raise ValueError(
                "Expected last dim of `index_tensor` to be the same as "
                "`sparse_tensor.sparse_dim()`, got "
                f"{str(index_tensor.shape[-1])} and {sparse_tensor.sparse_dim()}, "
                "respectively."
            )

    sparse_tensor_indices_linear, _, dim_linear_offsets = flatten_sparse_indices(
        sparse_tensor, 0, sparse_tensor.sparse_dim() - 1
    )
    sparse_tensor_indices_linear.squeeze_(0)

    # Repeat the index flattening for the index tensor. The sparse tensor's indices
    # were already flattened by flatten_sparse_indices above.
    index_tensor_linearized = (index_tensor * dim_linear_offsets).sum(-1).view(-1)

    return (
        sparse_tensor_indices_linear,
        index_tensor_linearized,
    )
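
Because both outputs share one flattened coordinate system, set-style operations on sparse coordinates reduce to cheap integer comparisons. The sketch below is an editorial example (not from the library docs) that uses torch.isin to test which query indices correspond to stored elements; it assumes the import path shown above.

import torch
from pytorch_sparse_utils.indexing.utils import linearize_sparse_and_index_tensors

indices = torch.tensor([[0, 1, 2], [0, 2, 1]])
values = torch.tensor([1.0, 2.0, 3.0])
sparse = torch.sparse_coo_tensor(indices, values, (3, 3)).coalesce()

queries = torch.tensor([[0, 0], [1, 1], [2, 1]])  # (1, 1) is not stored
sparse_lin, query_lin = linearize_sparse_and_index_tensors(sparse, queries)

# Both tensors now live in one flattened coordinate space, so membership
# checks become simple integer lookups
present = torch.isin(query_lin, sparse_lin)
print(present)  # tensor([ True, False,  True])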

get_sparse_index_mapping(sparse_tensor, index_tensor, sanitize_linear_index_tensor=True)

Finds the locations along a sparse tensor's values tensor for specified sparse indices. Also returns a mask indicating which indices have values actually present in the sparse tensor. It works by flattening the sparse tensor's sparse dims and the index tensor to 1D (and converting n-d indices to raveled indices), then using searchsorted along the flattened sparse tensor indices.

Parameters:
  • sparse_tensor (Tensor) –

    Sparse tensor of dimension ..., M; where ... are S leading sparse dimensions and M is the dense dimension.

  • index_tensor (Tensor) –

    Long tensor of dimension ..., S; where ... are leading batch dimensions. Negative indices and indices outside the bounds of the sparse dimensions are not supported and will be considered unspecified, with the corresponding entry in is_specified_mask being set to False.

  • sanitize_linear_index_tensor (bool, default: True ) –

    If False, then the output values at linear_index_tensor[~is_specified_mask] will be the "insertion position" that would keep the sparse tensor's indices ordered. This is useful if you want to insert values, but means that sparse_tensor.values()[linear_index_tensor] will be potentially unsafe if some of the "insertion position" values are out of bounds. If this arg is True, linear_index_tensor[~is_specified_mask] values will be set to 0. Defaults to True.

Returns:
  • linear_index_tensor( Tensor ) –

    Long tensor of dimension ... of the locations in sparse_tensor.values() corresponding to the indices in index_tensor. Elements where is_specified_mask is False are handled according to the value of sanitize_linear_index_tensor.

  • is_specified_mask( Tensor ) –

    Boolean tensor of dimension ... that is True for indices in index_tensor where values were actually specified in the sparse tensor and False for indices that were unspecified in the sparse tensor.

Examples:

>>> # Create a sparse tensor
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 0]])
>>> values = torch.tensor([10.0, 20.0, 30.0])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3))
>>> # Look up existing indices
>>> lookup = torch.tensor([[0, 0], [1, 1], [2, 0]])
>>> positions, mask = get_sparse_index_mapping(sparse, lookup)
>>> positions
tensor([0, 1, 2])  # Positions in values tensor
>>> mask
tensor([True, True, True])  # All indices exist
>>> sparse.values()[positions]
tensor([10., 20., 30.])
>>> # Look up mix of existing and non-existing indices
>>> lookup = torch.tensor([[0, 0], [0, 1], [1, 0]])
>>> positions, mask = get_sparse_index_mapping(sparse, lookup)
>>> positions
tensor([0, 0, 0])  # Non-existing indices mapped to 0 (sanitized)
>>> mask
tensor([True, False, False])  # Only first index exists
>>> # With sanitize=False to get insertion positions
>>> positions, mask = get_sparse_index_mapping(sparse, lookup,
...                                            sanitize_linear_index_tensor=False)
>>> positions
tensor([0, 1, 1])  # Position 1 is where [0,1] and [1,0] would be inserted
>>> mask
tensor([True, False, False])
>>> # Batch lookup
>>> batch_lookup = torch.tensor([[[0, 0], [2, 0]],
...                              [[1, 1], [0, 2]]])
>>> positions, mask = get_sparse_index_mapping(sparse, batch_lookup)
>>> positions
tensor([[0, 2],
        [1, 0]])
>>> mask
tensor([[True, True],
        [True, False]])
>>> # Out of bounds indices
>>> lookup = torch.tensor([[0, 0], [3, 0], [-1, 0]])  # 3 and -1 are out of bounds
>>> positions, mask = get_sparse_index_mapping(sparse, lookup)
>>> mask
tensor([True, False, False])  # Out of bounds treated as unspecified
Source code in pytorch_sparse_utils/indexing/utils.py
@torch.jit.script
def get_sparse_index_mapping(
    sparse_tensor: Tensor,
    index_tensor: Tensor,
    sanitize_linear_index_tensor: bool = True,
) -> tuple[Tensor, Tensor]:
    """Finds the locations along a sparse tensor's values tensor for specified
    sparse indices. Also returns a mask indicating which indices have values
    actually present in the sparse tensor. It works by flattening the sparse
    tensor's sparse dims and the index tensor to 1D (and converting n-d indices
    to raveled indices), then using searchsorted along the flattened sparse
    tensor indices.

    Args:
        sparse_tensor (Tensor): Sparse tensor of dimension ..., M; where ... are
            S leading sparse dimensions and M is the dense dimension.
        index_tensor (Tensor): Long tensor of dimension ..., S; where ... are
            leading batch dimensions. Negative indices and indices outside the
            bounds of the sparse dimensions are not supported and will
            be considered unspecified, with the corresponding entry in
            is_specified_mask being set to False.
        sanitize_linear_index_tensor (bool): If False, then the output values at
            linear_index_tensor[~is_specified_mask] will be the "insertion position"
            that would keep the sparse tensor's indices ordered. This is useful if you
            want to insert values, but means that
            sparse_tensor.values()[linear_index_tensor] will be potentially unsafe if
            some of the "insertion position" values are out of bounds. If this arg is
            True, linear_index_tensor[~is_specified_mask] values will be set to 0.
            Defaults to True.

    Returns:
        linear_index_tensor: Long tensor of dimension ... of the locations in
            sparse_tensor.values() corresponding to the indices in index_tensor.
            Elements where is_specified_mask is False are handled according to the
            value of sanitize_linear_index_tensor.
        is_specified_mask: Boolean tensor of dimension ... that is True for
            indices in index_tensor where values were actually specified in
            the sparse tensor and False for indices that were unspecified in
            the sparse tensor.

    Examples:
        >>> # Create a sparse tensor
        >>> indices = torch.tensor([[0, 1, 2], [0, 1, 0]])
        >>> values = torch.tensor([10.0, 20.0, 30.0])
        >>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3))

        >>> # Look up existing indices
        >>> lookup = torch.tensor([[0, 0], [1, 1], [2, 0]])
        >>> positions, mask = get_sparse_index_mapping(sparse, lookup)
        >>> positions
        tensor([0, 1, 2])  # Positions in values tensor
        >>> mask
        tensor([True, True, True])  # All indices exist
        >>> sparse.values()[positions]
        tensor([10., 20., 30.])

        >>> # Look up mix of existing and non-existing indices
        >>> lookup = torch.tensor([[0, 0], [0, 1], [1, 0]])
        >>> positions, mask = get_sparse_index_mapping(sparse, lookup)
        >>> positions
        tensor([0, 0, 0])  # Non-existing indices mapped to 0 (sanitized)
        >>> mask
        tensor([True, False, False])  # Only first index exists

        >>> # With sanitize=False to get insertion positions
        >>> positions, mask = get_sparse_index_mapping(sparse, lookup,
        ...                                            sanitize_linear_index_tensor=False)
        >>> positions
        tensor([0, 1, 1])  # Position 1 is where [0,1] and [1,0] would be inserted
        >>> mask
        tensor([True, False, False])

        >>> # Batch lookup
        >>> batch_lookup = torch.tensor([[[0, 0], [2, 0]],
        ...                              [[1, 1], [0, 2]]])
        >>> positions, mask = get_sparse_index_mapping(sparse, batch_lookup)
        >>> positions
        tensor([[0, 2],
                [1, 0]])
        >>> mask
        tensor([[True, True],
                [True, False]])

        >>> # Out of bounds indices
        >>> lookup = torch.tensor([[0, 0], [3, 0], [-1, 0]])  # 3 and -1 are out of bounds
        >>> positions, mask = get_sparse_index_mapping(sparse, lookup)
        >>> mask
        tensor([True, False, False])  # Out of bounds treated as unspecified
    """
    sparse_dim = sparse_tensor.sparse_dim()
    sparse_nnz = sparse_tensor._nnz()
    sparse_tensor_shape = torch._shape_as_tensor(sparse_tensor).to(
        device=index_tensor.device
    )
    sparse_shape = sparse_tensor_shape[:sparse_dim]

    # check for empty sparse tensor
    if sparse_nnz == 0:
        linear_index_tensor = index_tensor.new_zeros(index_tensor.shape[:-1])
        is_specified_mask = index_tensor.new_zeros(
            index_tensor.shape[:-1], dtype=torch.bool
        )
        return linear_index_tensor, is_specified_mask

    # Check for out of bounds indices (below 0 or outside tensor dim)
    out_of_bounds_indices = torch.any(index_tensor < 0, -1)
    out_of_bounds_indices |= torch.any(index_tensor >= sparse_shape, -1)

    # put dummy value of 0 in the OOB indices.
    # Maybe it'll make the linearization computations and searchsorted faster:
    # a compromise between just giving searchsorted random indices to find vs
    # causing a cpu sync to call nonzeros to filter them out
    index_tensor = index_tensor.masked_fill(out_of_bounds_indices.unsqueeze(-1), 0)
    (
        sparse_tensor_indices_linearized,
        index_tensor_linearized,
    ) = linearize_sparse_and_index_tensors(sparse_tensor, index_tensor)

    # The dummy value of 0 should always return searched index of 0 since
    # the sparse_tensor_indices_linearized values are always nonnegative.
    # Should be faster to find than random search values.
    linear_index_tensor = torch.searchsorted(  # binary search
        sparse_tensor_indices_linearized, index_tensor_linearized
    )

    # linear_index_tensor is distinct from index_tensor_linearized in that
    # index_tensor_linearized has the flattened version of the index in the sparse
    # tensor, while linear_index_tensor has the corresponding index in the sparse
    # tensor's values() tensor

    # guard against IndexError
    if sanitize_linear_index_tensor:
        index_clamped = linear_index_tensor.clamp_max_(sparse_nnz - 1)
    else:
        index_clamped = linear_index_tensor.clamp_max(sparse_nnz - 1)

    # Check if the indices were specified by checking for an exact match at the
    # resultant searched indices
    is_specified_mask: Tensor = (
        sparse_tensor_indices_linearized[index_clamped] == index_tensor_linearized
    )
    is_specified_mask &= ~out_of_bounds_indices.view(-1)

    if sanitize_linear_index_tensor:
        # Set unspecified positions to 0, per the documented behavior of
        # sanitize_linear_index_tensor=True
        linear_index_tensor = linear_index_tensor.masked_fill(~is_specified_mask, 0)

    linear_index_tensor = linear_index_tensor.view(index_tensor.shape[:-1])
    is_specified_mask = is_specified_mask.view(index_tensor.shape[:-1])

    return linear_index_tensor, is_specified_mask
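
A typical use of the returned mapping is to gather from sparse_tensor.values() and then mask out lookups that were never stored. The following sketch is editorial (not part of the library) and assumes the import path shown above.

import torch
from pytorch_sparse_utils.indexing.utils import get_sparse_index_mapping

indices = torch.tensor([[0, 1, 2], [0, 1, 0]])
values = torch.tensor([10.0, 20.0, 30.0])
sparse = torch.sparse_coo_tensor(indices, values, (3, 3)).coalesce()

lookup = torch.tensor([[0, 0], [0, 1], [2, 0]])  # [0, 1] is unspecified
positions, mask = get_sparse_index_mapping(sparse, lookup)

# Emulate dense-style indexing with a default of 0 for unspecified entries
gathered = sparse.values()[positions]
gathered = torch.where(mask, gathered, torch.zeros_like(gathered))
print(gathered)  # tensor([10.,  0., 30.])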

gather_mask_and_fill(values, indices, mask, fill=None)

Efficiently gathers elements from an ND tensor, applies a mask, and fills masked positions.

This function performs the equivalent of

    out = values[indices]
    out[~mask] = fill.expand_as(out)[~mask]  # or 0

but uses torch.index_select for better performance. It retrieves values at the specified indices and fills positions where the mask is False with either zeros (default) or the provided fill values.

Parameters:
  • values (Tensor) –

    Source tensor to gather from, may be 1D with shape (N) or n-D with shape (N, D0, D1, ...), where N is the number of elements and D are potentially multiple feature dimensions.

  • indices (Tensor) –

    Long tensor of indices into the first dimension of values. Can be of any shape.

  • mask (Tensor) –

    Boolean tensor with the same shape as indices. True indicates positions to keep, False indicates positions to be filled with zeros or with values from fill.

  • fill (Optional[Tensor], default: None ) –

    A tensor that must be broadcast-compatible with the final output shape. It is inserted at positions where mask is False. When None (default), a zero tensor is used.

Returns:
  • Tensor( Tensor ) –

    The gathered and masked values, with shape (*indices.shape, *values.shape[1:]) (or indices.shape when values is 1D). Contains values from the source tensor at the specified indices, with masked positions filled with zeros or from fill.

Raises:
  • ValueError

    If indices and mask have different shapes.

Examples:

>>> # Basic usage: gather and mask 1D values
>>> values = torch.tensor([10.0, 20.0, 30.0, 40.0])
>>> indices = torch.tensor([0, 2, 1, 3])
>>> mask = torch.tensor([True, True, False, True])
>>> gather_mask_and_fill(values, indices, mask)
tensor([10., 30.,  0., 40.])
>>> # Multi-dimensional values
>>> values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> indices = torch.tensor([2, 0, 1])
>>> mask = torch.tensor([True, False, True])
>>> gather_mask_and_fill(values, indices, mask)
tensor([[5., 6.],
        [0., 0.],
        [3., 4.]])
>>> # 2D indices with custom fill
>>> values = torch.tensor([10.0, 20.0, 30.0])
>>> indices = torch.tensor([[0, 1], [2, 0]])
>>> mask = torch.tensor([[True, False], [True, True]])
>>> gather_mask_and_fill(values, indices, mask, fill=torch.tensor(-1.0))
tensor([[10., -1.],
        [30., 10.]])
Source code in pytorch_sparse_utils/indexing/utils.py
@torch.jit.script
def gather_mask_and_fill(
    values: Tensor, indices: Tensor, mask: Tensor, fill: Optional[Tensor] = None
) -> Tensor:
    """Efficiently gathers elements from an ND tensor, applies a mask, and fills masked
    positions.

    This function performs the equivalent of
    `out = values[indices]
    out[~mask] = fill.expand_as(out)[~mask]  # or 0
    `
    but uses torch.index_select for better performance. It retrieves values at the
    specified indices and fills positions where the mask is False with either zeros
    (default) or the provided fill values.

    Args:
        values (Tensor): Source tensor to gather from, may be 1D with shape (N)
            or n-D with shape (N, D0, D1, ...), where N is the number of elements
            and D are potentially multiple feature dimensions.
        indices (Tensor): Long tensor of indices into the first dimension of values.
            Can be of any shape.
        mask (Tensor): Boolean tensor with the same shape as indices. True indicates
            positions to keep, False indicates positions to be filled with zeros or
            with values from `fill`.
        fill (Optional[Tensor]): A tensor that must be broadcast-compatible with the
            final output shape. It is inserted at positions where `mask` is False.
            When None (default), a zero tensor is used.

    Returns:
        Tensor: The gathered and masked values, with shape
            (*indices.shape, *values.shape[1:]), or indices.shape when `values` is 1D.
            Contains values from the source tensor at the specified indices, with
            masked positions filled with zeros or from `fill`.

    Raises:
        ValueError: If indices and mask have different shapes.

    Examples:
        >>> # Basic usage: gather and mask 1D values
        >>> values = torch.tensor([10.0, 20.0, 30.0, 40.0])
        >>> indices = torch.tensor([0, 2, 1, 3])
        >>> mask = torch.tensor([True, True, False, True])
        >>> gather_mask_and_fill(values, indices, mask)
        tensor([10., 30.,  0., 40.])

        >>> # Multi-dimensional values
        >>> values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
        >>> indices = torch.tensor([2, 0, 1])
        >>> mask = torch.tensor([True, False, True])
        >>> gather_mask_and_fill(values, indices, mask)
        tensor([[5., 6.],
                [0., 0.],
                [3., 4.]])

        >>> # 2D indices with custom fill
        >>> values = torch.tensor([10.0, 20.0, 30.0])
        >>> indices = torch.tensor([[0, 1], [2, 0]])
        >>> mask = torch.tensor([[True, False], [True, True]])
        >>> gather_mask_and_fill(values, indices, mask, fill=torch.tensor(-1.0))
        tensor([[10., -1.],
                [30., 10.]])
    """
    input_values_1d = False
    if values.ndim == 1:
        input_values_1d = True
        values = values.unsqueeze(1)

    if indices.shape != mask.shape:
        raise ValueError(
            "Expected indices and mask to have same shape, got "
            f"{indices.shape} and {mask.shape}"
        )

    indices_flat = indices.reshape(-1)
    mask_flat = mask.reshape(-1)

    # figure out how much to broadcast
    value_dims = values.shape[1:]
    n_value_dims = values.ndim - 1

    new_shape = indices.shape
    if not input_values_1d:
        new_shape += value_dims

    # pre-mask the indices to guard against unsafe indices in the masked portion
    indices_flat = torch.where(mask_flat, indices_flat, torch.zeros_like(indices_flat))

    selected = values.index_select(0, indices_flat)

    # unsqueeze mask
    mask_flat = mask_flat.view((mask_flat.size(0),) + (1,) * n_value_dims)

    if fill is None:
        selected.masked_fill_(~mask_flat, 0)
    else:
        fill_broadcast = fill.expand(new_shape).reshape(selected.shape)
        selected = torch.where(mask_flat, selected, fill_broadcast)

    selected = selected.reshape(new_shape)
    return selected
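
Putting the utilities together, the sketch below (editorial, not from the library docs) combines get_sparse_index_mapping with gather_mask_and_fill to perform a batched gather from a sparse tensor that has a dense feature dimension; unspecified locations come back as zero vectors. The import path follows the source paths shown above.

import torch
from pytorch_sparse_utils.indexing.utils import (
    gather_mask_and_fill,
    get_sparse_index_mapping,
)

# Sparse tensor with 2 sparse dims and a dense feature dim of size 2
indices = torch.tensor([[0, 1, 2], [0, 1, 0]])
values = torch.tensor([[1.0, 1.5], [2.0, 2.5], [3.0, 3.5]])
sparse = torch.sparse_coo_tensor(indices, values, (3, 3, 2)).coalesce()

# Batched lookup of shape (2, 2, 2): two batches of two 2D indices each
lookup = torch.tensor([[[0, 0], [1, 1]],
                       [[2, 0], [0, 2]]])
positions, mask = get_sparse_index_mapping(sparse, lookup)

# Unspecified lookups ([0, 2] here) come back as zero feature vectors
out = gather_mask_and_fill(sparse.values(), positions, mask)
print(out.shape)  # torch.Size([2, 2, 2])
print(out)
# tensor([[[1.0000, 1.5000],
#          [2.0000, 2.5000]],
#
#         [[3.0000, 3.5000],
#          [0.0000, 0.0000]]])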