Indexing utilities
The indexing module provides operations for selecting, slicing, gathering from, and scattering into sparse tensors.
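The examples on this page assume that torch is imported and that the documented functions are in scope. A minimal import sketch, using the module paths shown in the source listings below (the package may also re-export these names at a higher level):
>>> import torch
>>> from pytorch_sparse_utils.indexing.basics import (
...     sparse_select, sparse_index_select, batch_sparse_index,
... )
>>> from pytorch_sparse_utils.indexing.scatter import scatter_to_sparse_tensor
>>> from pytorch_sparse_utils.indexing.unique import unique_rows
>>> from pytorch_sparse_utils.indexing.misc import union_sparse_indices
>>> from pytorch_sparse_utils.indexing.utils import (
...     flatten_nd_indices, unflatten_nd_indices, flatten_sparse_indices,
...     linearize_sparse_and_index_tensors, get_sparse_index_mapping,
...     gather_mask_and_fill,
... )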
Basic Operations
sparse_select(tensor, axis, index)
Selects a single subtensor from a sparse tensor along the specified axis.
This function extracts a slice of a sparse tensor by selecting elements where the coordinate along the specified sparse axis matches the given index value. Unlike PyTorch's built-in indexing, this implementation properly handles gradient flow for sparse tensors.
Examples:
Create a sparse tensor and select slices:
>>> # Create a sparse tensor with dimensions (3, 4, 2), where the last
>>> # dimension is a dense dimension
>>> i = torch.tensor([[0, 1], [1, 3], [2, 2]]).T
>>> v = torch.tensor([[1.0, 1.5], [2.0, 2.5], [3.0, 3.5]])
>>> x = torch.sparse_coo_tensor(i, v, (3, 4, 2))
>>>
>>> # Select values along sparse dimension 0
>>> slice_0 = sparse_select(x, 0, 1) # Get elements where dim 0 == 1
>>> # slice_0 has shape (4, 2)
>>>
>>> # Select values along dense dimension 2
>>> slice_2 = sparse_select(x, 2, 0) # Get elements where dim 2 == 0
>>> # slice_2 has shape (3, 4)
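Because sparse_select is advertised as gradient-safe, here is a small additional sketch of backpropagating through a selected slice (not part of the original docstring; the gradient values shown follow standard autograd rules):
>>> # Gradient flow through sparse_select
>>> i = torch.tensor([[0, 1], [1, 3], [2, 2]]).T
>>> v = torch.tensor([[1.0, 1.5], [2.0, 2.5], [3.0, 3.5]], requires_grad=True)
>>> x = torch.sparse_coo_tensor(i, v, (3, 4, 2))
>>> sparse_select(x, 0, 1).to_dense().sum().backward()
>>> v.grad  # only the entry at sparse index (1, 3) was selected
tensor([[0., 0.],
        [1., 1.],
        [0., 0.]])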
Source code in pytorch_sparse_utils/indexing/basics.py
@torch.jit.script
def sparse_select(tensor: Tensor, axis: int, index: int) -> Tensor:
"""Selects a single subtensor from a sparse tensor along the specified axis.
This function extracts a slice of a sparse tensor by selecting elements
where the coordinate along the specified sparse axis matches the given index value.
Unlike PyTorch's built-in indexing, this implementation properly handles
gradient flow for sparse tensors.
Args:
tensor (Tensor): The input sparse tensor.
axis (int): The dimension along which to select the subtensor. Negative axes
are supported.
index (int): The index to select along the specified axis. Negative indices
are supported.
Returns:
Tensor: A new sparse tensor with one fewer dimension than the input tensor.
The shape will be tensor.shape[:axis] + tensor.shape[axis+1:].
Raises:
ValueError: If the input tensor is not sparse, or if the axis or index are
out of bounds.
Examples:
Create a sparse tensor and select slices:
>>> # Create a sparse tensor with dimensions (3, 4, 2), where the last
>>> # dimension is a dense dimension
>>> i = torch.tensor([[0, 1], [1, 3], [2, 2]]).T
>>> v = torch.tensor([[1.0, 1.5], [2.0, 2.5], [3.0, 3.5]])
>>> x = torch.sparse_coo_tensor(i, v, (3, 4, 2))
>>>
>>> # Select values along sparse dimension 0
>>> slice_0 = sparse_select(x, 0, 1) # Get elements where dim 0 == 1
>>> # slice_0 has shape (4, 2)
>>>
>>> # Select values along dense dimension 2
>>> slice_2 = sparse_select(x, 2, 0) # Get elements where dim 2 == 0
>>> # slice_2 has shape (3, 4)
"""
if not tensor.is_sparse:
raise ValueError("Input tensor is not sparse.")
if not isinstance(index, int):
raise ValueError(f"Expected integer index, got type {type(index)}")
# Normalize negative axis
orig_axis = axis
if axis < 0:
axis = tensor.ndim + axis
# Validate axis
if axis < 0 or axis >= tensor.ndim:
raise ValueError(
f"Axis {orig_axis} out of bounds for tensor with {tensor.ndim} dimensions"
)
# Normalize negative index
orig_index = index
if index < 0:
index = tensor.size(axis) + index
# check if index is in bounds
if index < 0 or index >= tensor.size(axis):
# weird string construction to work around torchscript not liking concatting
# multiple f strings
error_str = "Index " + str(orig_index) + " is out of bounds on axis "
error_str += (
str(orig_axis) + " for tensor with shape " + str(tensor.shape) + "."
)
raise ValueError(error_str)
tensor = tensor.coalesce()
sparse_dims = tensor.sparse_dim()
if axis < sparse_dims:
# Selection along a sparse dimension
index_mask = tensor.indices()[axis] == index
values = tensor.values()[index_mask]
indices = torch.cat(
[
tensor.indices()[:axis, index_mask],
tensor.indices()[axis + 1 :, index_mask],
]
)
else:
# Selecting along a dense dimension
# This means we just index the values tensor along the appropriate dim
# and the sparse indices stay the same
indices = tensor.indices()
dense_axis = axis - sparse_dims + 1 # +1 because first dim of values is nnz
values = tensor.values().select(dense_axis, index)
return torch.sparse_coo_tensor(
indices, values, tensor.shape[:axis] + tensor.shape[axis + 1 :]
).coalesce()
sparse_index_select(tensor, axis, index, check_bounds=True, disable_builtin_fallback=False)
Selects values from a sparse tensor along a specified dimension.
This function is equivalent to Tensor.index_select(axis, index) but offers full support for the backward pass for sparse tensors. It returns a new sparse tensor containing only the values at the specified indices along the given axis.
This function falls back to the built-in Tensor.index_select(axis, index) when gradients are not required. Benchmarking seems to indicate the built-in version is generally faster and more memory efficient except for some specialized situations on CUDA. You can always use the custom implementation by setting this function's input argument disable_builtin_fallback to True.
Note that the built-in Tensor.index_select will trigger device-side assert errors if it is given indices outside the bounds of a sparse tensor. Unlike the built-in version, this function validates that indices are within bounds (when check_bounds=True), making it a safer alternative even when gradient support isn't needed.
Examples:
>>> # Create a sparse tensor with shape (4, 5)
>>> indices = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
>>> values = torch.tensor([10.0, 20.0, 30.0, 40.0])
>>> sparse_tensor = torch.sparse_coo_tensor(indices, values, (4, 5))
>>> # Select rows 1 and 3
>>> selected = sparse_index_select(sparse_tensor, axis=0, index=torch.tensor([1, 3]))
>>> selected.to_dense()
tensor([[ 0., 20., 0., 0., 0.],
[ 0., 0., 0., 0., 40.]])
>>> # Select columns 2, 3, and 4
>>> selected = sparse_index_select(sparse_tensor, axis=1, index=torch.tensor([2, 3, 4]))
>>> selected.to_dense()
tensor([[ 0., 0., 0.],
[20., 0., 0.],
[ 0., 30., 0.],
[ 0., 0., 40.]])
>>> # Works with negative axis indexing
>>> selected = sparse_index_select(sparse_tensor, axis=-1, index=torch.tensor([4, 2]))
>>> selected.to_dense()
tensor([[ 0., 0.],
[ 0., 20.],
[ 0., 0.],
[40., 30.]])
>>> # Gradient support example
>>> values = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 0]])
>>> x = torch.sparse_coo_tensor(indices, values, (3, 2))
>>> selected = sparse_index_select(x, axis=0, index=torch.tensor([0, 2]))
>>> loss = selected.sum()
>>> loss.backward()
>>> values.grad # Gradient flows back to original values
tensor([1., 0., 1.])
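The keyword arguments described above can be combined with the usual call signature; a short behavioral sketch (not a benchmark):
>>> indices = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
>>> values = torch.tensor([10.0, 20.0, 30.0, 40.0])
>>> sparse_tensor = torch.sparse_coo_tensor(indices, values, (4, 5))
>>> # Force the custom gradient-tracking path even though no gradients are needed
>>> selected = sparse_index_select(sparse_tensor, axis=0, index=torch.tensor([1, 3]),
...                                disable_builtin_fallback=True)
>>> # Skip the bounds check (avoids a CPU sync) when indices are known to be in-bounds
>>> selected = sparse_index_select(sparse_tensor, axis=0, index=torch.tensor([1, 3]),
...                                check_bounds=False)
>>> # With check_bounds=True (the default), out-of-bounds indices raise ValueError
>>> try:
...     sparse_index_select(sparse_tensor, axis=0, index=torch.tensor([7]))
... except ValueError as e:
...     print("Error:", e)
Error: Index tensor has entries out of bounds for axis 0 with size 4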
Source code in pytorch_sparse_utils/indexing/basics.py
@torch.jit.script
def sparse_index_select(
tensor: Tensor,
axis: int,
index: Tensor,
check_bounds: bool = True,
disable_builtin_fallback: bool = False,
) -> Tensor:
"""Selects values from a sparse tensor along a specified dimension.
This function is equivalent to Tensor.index_select(axis, index) but offers full
support for the backward pass for sparse tensors. It returns a new sparse
tensor containing only the values at the specified indices along the given axis.
This function falls back to the built-in Tensor.index_select(axis, index)
when gradients are not required. Benchmarking seems to indicate the built-in
version is generally faster and more memory efficient except for some specialized
situations on CUDA. You can always use the custom implementation by setting this
function's input argument disable_builtin_fallback to True.
Note that the built-in Tensor.index_select will trigger Device-Side Assert errors
if it is given indices outside the bounds of a sparse tensor.
Unlike the built-in Tensor.index_select, this function validates that indices
are within bounds (when check_bounds=True), making it a safer alternative even
when gradient support isn't needed.
Args:
tensor (Tensor): The input sparse tensor from which to select values.
axis (int): The dimension along which to select values. Can be negative
to index from the end.
index (Tensor): The indices of the values to select along the specified
dimension. Must be a 1D tensor or scalar of integer dtype.
check_bounds (bool, optional): Whether to check if indices are within bounds.
Set to False if indices are guaranteed to be in-bounds to avoid a CPU sync
on CUDA tensors. Benchmarking shows the bounds check leads to an overhead
of about 5% on cpu and 10% on cuda. Defaults to True.
disable_builtin_fallback (bool, optional): Whether to always use the custom
gradient-tracking version of index_select, even when gradients are not
needed. Does nothing if the input sparse tensor does require gradients.
Defaults to False.
Returns:
Tensor: A new sparse tensor containing the selected values.
Raises:
ValueError:
- If the input tensor is not sparse.
- If the index tensor has invalid shape or is not an integer tensor.
- If the axis is out of bounds for tensor dimensions.
- If check_bounds is True and the index tensor contains out-of-bounds
indices.
Examples:
>>> # Create a sparse tensor with shape (4, 5)
>>> indices = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
>>> values = torch.tensor([10.0, 20.0, 30.0, 40.0])
>>> sparse_tensor = torch.sparse_coo_tensor(indices, values, (4, 5))
>>> # Select rows 1 and 3
>>> selected = sparse_index_select(sparse_tensor, axis=0, index=torch.tensor([1, 3]))
>>> selected.to_dense()
tensor([[ 0., 20., 0., 0., 0.],
[ 0., 0., 0., 0., 40.]])
>>> # Select columns 2, 3, and 4
>>> selected = sparse_index_select(sparse_tensor, axis=1, index=torch.tensor([2, 3, 4]))
>>> selected.to_dense()
tensor([[ 0., 0., 0.],
[20., 0., 0.],
[ 0., 30., 0.],
[ 0., 0., 40.]])
>>> # Works with negative axis indexing
>>> selected = sparse_index_select(sparse_tensor, axis=-1, index=torch.tensor([4, 2]))
>>> selected.to_dense()
tensor([[ 0., 0.],
[ 0., 20.],
[ 0., 0.],
[40., 30.]])
>>> # Gradient support example
>>> values = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 0]])
>>> x = torch.sparse_coo_tensor(indices, values, (3, 2))
>>> selected = sparse_index_select(x, axis=0, index=torch.tensor([0, 2]))
>>> loss = selected.sum()
>>> loss.backward()
>>> values.grad # Gradient flows back to original values
tensor([1., 0., 1.])
"""
if not tensor.is_sparse:
raise ValueError("Input tensor must be sparse")
# Validate index tensor
if torch.is_floating_point(index):
raise ValueError(f"Received index tensor of non-integer dtype: {index.dtype}")
if index.ndim > 1:
raise ValueError(f"Index tensor must be 0D or 1D, got {index.ndim}D")
elif index.ndim == 0:
index = index.unsqueeze(0)
# Normalize negative axis
orig_axis = axis
if axis < 0:
axis = tensor.ndim + axis
# Validate axis
if axis < 0 or axis >= tensor.ndim:
raise ValueError(
f"Axis {orig_axis} out of bounds for tensor with {tensor.ndim} dimensions"
)
# Validate index bounds (optional)
if check_bounds and index.numel() > 0:
out_of_bounds = ((index < 0) | (index >= tensor.shape[axis])).any()
if out_of_bounds: # cpu sync happens here
raise ValueError(
f"Index tensor has entries out of bounds for axis {orig_axis} with size {tensor.shape[axis]}"
)
if not tensor.requires_grad and not disable_builtin_fallback:
# Fall back to built-in implementation
return tensor.index_select(axis, index.long()).coalesce()
tensor = tensor.coalesce()
tensor_indices = tensor.indices()
tensor_values = tensor.values()
new_indices, new_values = _sparse_index_select_inner(
tensor_indices, tensor_values, axis, index
)
new_shape = list(tensor.shape)
new_shape[axis] = len(index)
return torch.sparse_coo_tensor(new_indices.long(), new_values, new_shape).coalesce()
Bulk Indexing
batch_sparse_index(sparse_tensor, index_tensor, check_all_specified=False)
Batch selection of elements from a torch sparse tensor. The index tensor may have arbitrary batch dimensions.
If dense_tensor = sparse_tensor.to_dense(), the equivalent dense indexing operation would be dense_tensor[index_tensor.unbind(-1)].
Examples:
>>> # Create a 3D sparse tensor with 2 sparse dims and 1 dense dim
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 2]])
>>> values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3, 2))
>>> # Single index lookup
>>> index = torch.tensor([[1, 1]])
>>> result, mask = batch_sparse_index(sparse, index)
>>> result
tensor([[3., 4.]])
>>> mask
tensor([True])
>>> # Batch lookup with some unspecified indices
>>> batch_indices = torch.tensor([[0, 0], [1, 1], [2, 0]]) # Last one unspecified
>>> results, masks = batch_sparse_index(sparse, batch_indices)
>>> results
tensor([[1., 2.],
[3., 4.],
[0., 0.]]) # Zeros for unspecified index
>>> masks
tensor([ True, True, False])
>>> # 2D batch of indices
>>> batch_indices = torch.tensor([[[0, 0], [2, 2]],
... [[1, 1], [0, 1]]])
>>> results, masks = batch_sparse_index(sparse, batch_indices)
>>> results.shape
torch.Size([2, 2, 2])
>>> masks
tensor([[ True, True],
[ True, False]])
>>> # Check all specified - will raise error if any index is missing
>>> try:
... results, masks = batch_sparse_index(sparse, torch.tensor([[0, 2]]),
... check_all_specified=True)
... except ValueError as e:
... print("Error:", e)
Error: `check_all_specified` was set to True but not all gathered values were specified
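The dense-indexing equivalence stated above can be checked directly; a small sketch using the same tensor as the examples:
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 2]])
>>> values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3, 2))
>>> index_tensor = torch.tensor([[0, 0], [1, 1], [2, 0]])  # last index unspecified
>>> dense_result = sparse.to_dense()[index_tensor.unbind(-1)]
>>> sparse_result, mask = batch_sparse_index(sparse, index_tensor)
>>> torch.equal(dense_result, sparse_result)  # unspecified entries are zero in both
True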
Source code in pytorch_sparse_utils/indexing/basics.py
@torch.jit.script
def batch_sparse_index(
sparse_tensor: Tensor, index_tensor: Tensor, check_all_specified: bool = False
) -> tuple[Tensor, Tensor]:
"""Batch selection of elements from a torch sparse tensor. The index tensor may
have arbitrary batch dimensions.
If dense_tensor = sparse_tensor.to_dense(), the equivalent dense indexing
operation would be dense_tensor[index_tensor.unbind(-1)]
Args:
sparse_tensor (Tensor): Sparse tensor of dimension
[S0, S1, ..., Sn-1, D0, D1, ..., Dm-1], with n sparse dimensions and
m dense dimensions.
index_tensor (Tensor): Long tensor of dimension [B0, B1, ..., Bp-1, n]
with optional p leading batch dimensions and final dimension corresponding
to the sparse dimensions of the sparse tensor. Negative indices are not
supported and will be considered unspecified.
check_all_specified (bool): If True, this function will raise a
ValueError if any of the indices in `index_tensor` are not specified
in `sparse_tensor`. If False, selections at unspecified indices will be
returned with padding values of 0. Defaults to False.
Returns:
Tensor: Tensor of dimension [B0, B1, ..., Bp-1, D0, D1, ..., Dm-1].
Tensor: Boolean tensor of dimension [B0, B1, ..., Bp-1], where each element is
True if the corresponding index is a specified (nonzero) element of the
sparse tensor and False if not.
Raises:
ValueError: If `check_all_specified` is set to True and not all indices in
`index_tensor` had associated values specified in `sparse_tensor`, or if
`index_tensor` is a nested tensor (feature planned but not implemented yet)
Examples:
>>> # Create a 3D sparse tensor with 2 sparse dims and 1 dense dim
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 2]])
>>> values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3, 2))
>>> # Single index lookup
>>> index = torch.tensor([[1, 1]])
>>> result, mask = batch_sparse_index(sparse, index)
>>> result
tensor([[3., 4.]])
>>> mask
tensor([True])
>>> # Batch lookup with some unspecified indices
>>> batch_indices = torch.tensor([[0, 0], [1, 1], [2, 0]]) # Last one unspecified
>>> results, masks = batch_sparse_index(sparse, batch_indices)
>>> results
tensor([[1., 2.],
[3., 4.],
[0., 0.]]) # Zeros for unspecified index
>>> masks
tensor([ True, True, False])
>>> # 2D batch of indices
>>> batch_indices = torch.tensor([[[0, 0], [2, 2]],
... [[1, 1], [0, 1]]])
>>> results, masks = batch_sparse_index(sparse, batch_indices)
>>> results.shape
torch.Size([2, 2, 2])
>>> masks
tensor([[ True, True],
[ True, False]])
>>> # Check all specified - will raise error if any index is missing
>>> try:
... results, masks = batch_sparse_index(sparse, torch.tensor([[0, 2]]),
... check_all_specified=True)
... except ValueError as e:
... print("Error:", e)
Error: `check_all_specified` was set to True but not all gathered values were specified
"""
if index_tensor.is_nested:
raise ValueError("Nested index tensor not supported yet")
sparse_tensor = sparse_tensor.coalesce()
sparse_tensor_values = sparse_tensor.values()
dense_dim = sparse_tensor.dense_dim()
index_search, is_specified_mask = get_sparse_index_mapping(
sparse_tensor, index_tensor
)
if check_all_specified and not is_specified_mask.all():
raise ValueError(
"`check_all_specified` was set to True but not all gathered values "
"were specified"
)
selected: Tensor = gather_mask_and_fill(
sparse_tensor_values, index_search, is_specified_mask
)
out_shape = index_tensor.shape[:-1]
assert is_specified_mask.shape == out_shape
if dense_dim > 0:
out_shape = out_shape + (sparse_tensor.shape[-dense_dim:])
assert selected.shape == out_shape
return selected, is_specified_mask
Sparse Tensor Scatter
scatter_to_sparse_tensor(sparse_tensor, index_tensor, values, check_all_specified=False)
Batch updating of elements in a torch sparse tensor. Should be equivalent to sparse_tensor[index_tensor] = values. It works by flattening the sparse tensor's sparse dims and the index tensor to 1D (and converting n-d indices to raveled indices), then using index_copy along the flattened sparse tensor.
Notes
This function uses index_copy as the underlying mechanism to write new values, so duplicate indices in index_tensor will have the same result as other uses of index_copy, i.e., the result will depend on which copy occurs last. This imitates the behavior of scatter-like operations rather than the typical coalescing deduplication behavior of sparse tensors.
Examples:
>>> # Create a sparse tensor with values
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 0]])
>>> values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3, 2))
>>> # Update existing values
>>> update_indices = torch.tensor([[0, 0], [1, 1]])
>>> new_values = torch.tensor([[10.0, 20.0], [30.0, 40.0]])
>>> updated = scatter_to_sparse_tensor(sparse, update_indices, new_values)
>>> updated.to_dense()
tensor([[[10., 20.], # Updated
[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[30., 40.], # Updated
[ 0., 0.]],
[[5., 6.], # Unchanged
[ 0., 0.],
[ 0., 0.]]])
>>> # Add new values (scatter to unspecified locations)
>>> new_indices = torch.tensor([[0, 2], [2, 2]])
>>> new_values = torch.tensor([[100.0, 200.0], [300.0, 400.0]])
>>> updated = scatter_to_sparse_tensor(sparse, new_indices, new_values)
>>> updated.to_dense()[0, 2] # New value added
tensor([100., 200.])
>>> # Batch update with multiple indices
>>> batch_indices = torch.tensor([[[0, 0], [1, 1]],
... [[2, 0], [0, 1]]])
>>> batch_values = torch.tensor([[[11., 12.], [13., 14.]],
... [[15., 16.], [17., 18.]]])
>>> # Flatten batch dimensions
>>> flat_indices = batch_indices.reshape(-1, 2)
>>> flat_values = batch_values.reshape(-1, 2)
>>> updated = scatter_to_sparse_tensor(sparse, flat_indices, flat_values)
>>> # check_all_specified example
>>> indices = torch.tensor([[0, 0], [1, 1]])
>>> values = torch.tensor([1.0, 2.0])
>>> sparse = torch.sparse_coo_tensor(indices.T, values, (2, 2))
>>>
>>> # This will succeed (all indices exist)
>>> update_indices = torch.tensor([[0, 0]])
>>> update_values = torch.tensor([10.0])
>>> result = scatter_to_sparse_tensor(sparse, update_indices, update_values,
... check_all_specified=True)
>>>
>>> # This will raise ValueError (index [1, 0] doesn't exist)
>>> try:
... bad_indices = torch.tensor([[1, 0]])
... bad_values = torch.tensor([20.0])
... result = scatter_to_sparse_tensor(sparse, bad_indices, bad_values,
... check_all_specified=True)
... except ValueError as e:
... print("Error:", e)
Error: `check_all_specified` was set to True but not all gathered values were specified
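The index_copy-style duplicate handling described in the Notes can be seen with a repeated index; a sketch (the winning value depends on which copy occurs last, so it is shown as a comment rather than asserted):
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 0]])
>>> values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3, 2))
>>> dup_indices = torch.tensor([[0, 0], [0, 0]])          # same location twice
>>> dup_values = torch.tensor([[7.0, 7.0], [9.0, 9.0]])
>>> updated = scatter_to_sparse_tensor(sparse, dup_indices, dup_values)
>>> updated.to_dense()[0, 0]
>>> # -> either tensor([7., 7.]) or tensor([9., 9.]), whichever copy occurs last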
Source code in pytorch_sparse_utils/indexing/scatter.py
def scatter_to_sparse_tensor(
sparse_tensor: Tensor,
index_tensor: Tensor,
values: Tensor,
check_all_specified: bool = False,
) -> Tensor:
"""Batch updating of elements in a torch sparse tensor. Should be
equivalent to sparse_tensor[index_tensor] = values. It works by flattening
the sparse tensor's sparse dims and the index tensor to 1D (and converting
n-d indices to raveled indices), then using index_copy along the flattened
sparse tensor.
Args:
sparse_tensor (Tensor): Sparse tensor of dimension
[s0, s1, s2, ..., d0, d1, d2, ...]; where s0, s1, ... are
S leading sparse dimensions and d0, d1, d2, ... are D dense dimensions.
index_tensor (Tensor): Long tensor of dimension [b0, b1, b2, ..., S]; where
b0, b1, b2, ... are B leading batch dimensions.
values (Tensor): Tensor of dimension [b0, b1, b2, ... d0, d1, d2, ...], where
dimensions are as above.
check_all_specified (bool): If True, this function will throw a ValueError
if any of the indices specified in index_tensor are not already present
in sparse_tensor. Default: False.
Returns:
Tensor: sparse_tensor with the new values scattered into it
Notes:
This function uses index_copy as the underlying mechanism to write new values,
so duplicate indices in index_tensor will have the same result as other
uses of index_copy, i.e., the result will depend on which copy occurs last.
This imitates the behavior of scatter-like operations rather than the
typical coalescing deduplication behavior of sparse tensors.
Examples:
>>> # Create a sparse tensor with values
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 0]])
>>> values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3, 2))
>>> # Update existing values
>>> update_indices = torch.tensor([[0, 0], [1, 1]])
>>> new_values = torch.tensor([[10.0, 20.0], [30.0, 40.0]])
>>> updated = scatter_to_sparse_tensor(sparse, update_indices, new_values)
>>> updated.to_dense()
tensor([[[10., 20.], # Updated
[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[30., 40.], # Updated
[ 0., 0.]],
[[5., 6.], # Unchanged
[ 0., 0.],
[ 0., 0.]]])
>>> # Add new values (scatter to unspecified locations)
>>> new_indices = torch.tensor([[0, 2], [2, 2]])
>>> new_values = torch.tensor([[100.0, 200.0], [300.0, 400.0]])
>>> updated = scatter_to_sparse_tensor(sparse, new_indices, new_values)
>>> updated.to_dense()[0, 2] # New value added
tensor([100., 200.])
>>> # Batch update with multiple indices
>>> batch_indices = torch.tensor([[[0, 0], [1, 1]],
... [[2, 0], [0, 1]]])
>>> batch_values = torch.tensor([[[11., 12.], [13., 14.]],
... [[15., 16.], [17., 18.]]])
>>> # Flatten batch dimensions
>>> flat_indices = batch_indices.reshape(-1, 2)
>>> flat_values = batch_values.reshape(-1, 2)
>>> updated = scatter_to_sparse_tensor(sparse, flat_indices, flat_values)
>>> # check_all_specified example
>>> indices = torch.tensor([[0, 0], [1, 1]])
>>> values = torch.tensor([1.0, 2.0])
>>> sparse = torch.sparse_coo_tensor(indices.T, values, (2, 2))
>>>
>>> # This will succeed (all indices exist)
>>> update_indices = torch.tensor([[0, 0]])
>>> update_values = torch.tensor([10.0])
>>> result = scatter_to_sparse_tensor(sparse, update_indices, update_values,
... check_all_specified=True)
>>>
>>> # This will raise ValueError (index [1, 0] doesn't exist)
>>> try:
... bad_indices = torch.tensor([[1, 0]])
... bad_values = torch.tensor([20.0])
... result = scatter_to_sparse_tensor(sparse, bad_indices, bad_values,
... check_all_specified=True)
... except ValueError as e:
... print("Error:", e)
Error: `check_all_specified` was set to True but not all gathered values were specified
"""
if index_tensor.is_nested:
assert values.is_nested
index_tensor = torch.cat(index_tensor.unbind())
values = torch.cat(values.unbind())
dense_dim = sparse_tensor.dense_dim()
sparse_dim = sparse_tensor.sparse_dim()
values_batch_dims = values.shape[:-dense_dim] if dense_dim else values.shape
if index_tensor.shape[:-1] != values_batch_dims:
raise ValueError(
"Expected matching batch dims for `index_tensor` and `values`, but got "
f"batch dims {index_tensor.shape[:-1]} and "
f"{values_batch_dims}, respectively."
)
sparse_tensor = sparse_tensor.coalesce()
sparse_tensor_values = sparse_tensor.values()
index_search, is_specified_mask = get_sparse_index_mapping(
sparse_tensor, index_tensor, sanitize_linear_index_tensor=False
)
all_specified = torch.all(is_specified_mask)
if check_all_specified and not all_specified:
raise ValueError(
"`check_all_specified` was set to True but not all gathered values "
"were specified"
)
# In-place update of existing values
if not sparse_tensor_values.requires_grad:
updated_values = sparse_tensor_values.index_copy_(
0, index_search[is_specified_mask], values[is_specified_mask]
)
else:
updated_values = sparse_tensor_values.index_copy(
0, index_search[is_specified_mask], values[is_specified_mask]
)
if all_specified: # No new values to append: tensor is fully updated
return torch.sparse_coo_tensor(
sparse_tensor.indices(),
updated_values,
sparse_tensor.shape,
dtype=sparse_tensor.dtype,
device=sparse_tensor.device,
is_coalesced=True,
)
# Need to append at least one new value: pre-sort the index tensor to save
# the final coalesce operation
# Pull out new values and indices to be added
new_insert_pos: Tensor = index_search[~is_specified_mask]
new_nd_indices = index_tensor[~is_specified_mask]
new_values = values[~is_specified_mask]
# Get sparse shape info for linearization
sparse_sizes = torch.tensor(
sparse_tensor.shape[:sparse_dim], device=sparse_tensor.device
)
if (new_nd_indices >= sparse_sizes.unsqueeze(0)).any():
raise ValueError(
"`index_tensor` has indices that are out of bounds of the original "
f"sparse tensor's sparse shape ({sparse_sizes})."
)
# Obtain linearized versions of all indices for sorting
old_indices_nd = sparse_tensor.indices()
linear_offsets = _make_linear_offsets(sparse_sizes)
new_indices_lin: Tensor = (new_nd_indices * linear_offsets).sum(-1)
# Find duplicate linear indices
unique_new_indices_lin, inverse = torch.unique(
new_indices_lin, sorted=True, return_inverse=True
)
# Use inverse of indices unique to write to new values tensor and tensor of
# insertion positions
deduped_new_values = new_values.new_empty(
(unique_new_indices_lin.size(0),) + new_values.shape[1:]
)
deduped_insert_pos = new_insert_pos.new_empty(unique_new_indices_lin.size(0))
# Deciding which duplicate value wins is offloaded to index_copy
deduped_new_values.index_copy_(0, inverse, new_values)
deduped_insert_pos.index_copy_(0, inverse, new_insert_pos)
# Convert uniqueified flattened indices to n-D for inclusion in indices tensor
unique_new_indices_nd = unflatten_nd_indices(
unique_new_indices_lin.unsqueeze(0), sparse_sizes, linear_offsets
)
# Concatenate old and new indices/values and sort
combined_indices, combined_values = _merge_sorted(
old_indices_nd,
unique_new_indices_nd,
updated_values,
deduped_new_values,
deduped_insert_pos,
)
return torch.sparse_coo_tensor(
combined_indices,
combined_values,
sparse_tensor.shape,
dtype=sparse_tensor.dtype,
device=sparse_tensor.device,
is_coalesced=True,
)
Miscellaneous Functions
unique_rows(tensor, sorted=True)
Returns the indices of the unique rows (first dimension) of a 2D integer tensor.
Examples:
>>> tensor = torch.tensor([[1, 2, 3],
... [4, 5, 6],
... [1, 2, 3], # Duplicate of row 0
... [7, 8, 9],
... [4, 5, 6]]) # Duplicate of row 1
>>> unique_indices = unique_rows(tensor)
>>> unique_indices
tensor([0, 1, 3])
>>> tensor[unique_indices] # All unique rows
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
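The implementation below also shifts each column to be nonnegative before hashing, so rows containing negative integers are handled; a quick sketch:
>>> t = torch.tensor([[-1,  2],
...                   [ 3, -4],
...                   [-1,  2]])  # row 2 duplicates row 0
>>> unique_rows(t)
tensor([0, 1])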
Source code in pytorch_sparse_utils/indexing/unique.py
@torch.jit.script
def unique_rows(tensor: Tensor, sorted: bool = True) -> Tensor:
"""Returns the indices of the unique rows (first dimension) of a 2D integer tensor.
Args:
tensor (Tensor): A 2D tensor of integer type.
sorted (bool): Whether to sort the indices of unique rows before returning.
If False, returned indices will be in lexicographic order of the rows.
Returns:
Tensor: A 1D tensor whose elements are the indices of the unique rows of
the input tensor, i.e., if the return tensor is `inds`, then
tensor[inds] gives a 2D tensor of all unique rows of the input tensor.
Raises:
OverflowError: If the tensor has values that are large enough to cause overflow
errors when hashing each row to a single value.
Examples:
>>> tensor = torch.tensor([[1, 2, 3],
... [4, 5, 6],
... [1, 2, 3], # Duplicate of row 0
... [7, 8, 9],
... [4, 5, 6]]) # Duplicate of row 1
>>> unique_indices = unique_rows(tensor)
>>> unique_indices
tensor([0, 1, 3])
>>> tensor[unique_indices] # All unique rows
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
"""
if tensor.ndim != 2:
raise ValueError(f"Expected a 2D tensor, got ndim={tensor.ndim}")
if torch.is_floating_point(tensor) or torch.is_complex(tensor):
raise ValueError(f"Expected integer tensor, got dtype={tensor.dtype}")
max_vals = tensor.max(0).values
min_vals = tensor.min(0).values
# Check for overflow problems
INT64_MAX = 9223372036854775807
if (max_vals >= INT64_MAX).any():
raise OverflowError(
f"Tensor contains values at or near maximum int64 value ({INT64_MAX}), "
"which would lead to overflow errors when computing unique rows."
)
log_sum = (max_vals + 1).log().sum()
log_max = torch.tensor(INT64_MAX, device=max_vals.device).log()
if log_sum > log_max:
raise OverflowError(
"Hashing rows would cause integer overflow. Maximum hashed row product is "
f"approx {log_sum.exp()} compared to max int64 value of {INT64_MAX}."
)
# Handle negative values by shifting to nonnegative
has_negs = min_vals < 0
if has_negs.any():
# Shift each column to be nonnegative
neg_shift = torch.where(has_negs, min_vals, min_vals.new_zeros([]))
tensor = tensor - neg_shift
max_vals = max_vals - neg_shift
tensor_flat, _ = flatten_nd_indices(tensor.T.long(), max_vals)
tensor_flat: Tensor = tensor_flat.squeeze(0)
unique_flat_indices, unique_inverse = torch.unique(tensor_flat, return_inverse=True)
unique_row_indices: Tensor = unique_inverse.new_full(
(unique_flat_indices.size(0),), tensor_flat.size(0)
)
unique_row_indices.scatter_reduce_(
0,
unique_inverse,
torch.arange(tensor_flat.size(0), device=tensor.device),
"amin",
)
if sorted:
unique_row_indices = unique_row_indices.sort().values
return unique_row_indices
union_sparse_indices(sparse_tensor_1, sparse_tensor_2)
Creates unified sparse tensors with the union of indices from both input tensors.
This function takes two sparse tensors and returns versions of them that share the same set of indices (the union of indices from both inputs). For indices present in only one of the tensors, zeros are filled in for the corresponding values in the other tensor.
This function is useful for ensuring a one-to-one correspondence between two sparse tensors' respective values() tensors, which in turn may be useful for elementwise value comparisons like loss functions.
Note
For very large sparse tensors, this operation may require significant memory for intermediate tensors.
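There is no rendered example for this function above, so here is a small sketch of the unioning behavior (the outputs follow from the definition: indices missing from one tensor get zero-filled values in that tensor):
>>> a = torch.sparse_coo_tensor(torch.tensor([[0, 1]]), torch.tensor([1.0, 2.0]), (4,))
>>> b = torch.sparse_coo_tensor(torch.tensor([[1, 3]]), torch.tensor([5.0, 6.0]), (4,))
>>> a_u, b_u = union_sparse_indices(a, b)
>>> a_u.indices()  # both outputs now share the union of indices {0, 1, 3}
tensor([[0, 1, 3]])
>>> a_u.values()
tensor([1., 2., 0.])
>>> b_u.values()
tensor([0., 5., 6.])
>>> # values() tensors are now aligned one-to-one, e.g. for an elementwise loss
>>> ((a_u.values() - b_u.values()) ** 2).mean()
tensor(15.3333)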
Source code in pytorch_sparse_utils/indexing/misc.py
@torch.jit.script
def union_sparse_indices(
sparse_tensor_1: Tensor, sparse_tensor_2: Tensor
) -> tuple[Tensor, Tensor]:
"""Creates unified sparse tensors with the union of indices from both input tensors.
This function takes two sparse tensors and returns versions of them that share the
same set of indices (the union of indices from both inputs). For indices present in
only one of the tensors, zeros are filled in for the corresponding values in the
other tensor.
This function is useful for ensuring a one-to-one correspondence between two
sparse tensors' respective values() tensors, which in turn may be useful for
elementwise value comparisons like loss functions.
Args:
sparse_tensor_1 (Tensor): First sparse tensor.
sparse_tensor_2 (Tensor): Second sparse tensor with the same sparse and dense
dimensions as sparse_tensor_1.
Returns:
tuple[Tensor, Tensor]: A tuple containing:
- tensor_1_unioned (Tensor): First tensor with indices expanded to include
all indices from the second tensor (with zeros for missing values).
- tensor_2_unioned (Tensor): Second tensor with indices expanded to include
all indices from the first tensor (with zeros for missing values).
Raises:
ValueError: If either input is not a sparse tensor or if the sparse and
dense dimensions don't match between tensors.
Note:
For very large sparse tensors, this operation may require significant memory
for intermediate tensors.
"""
if not sparse_tensor_1.is_sparse or not sparse_tensor_2.is_sparse:
raise ValueError(
"Expected two sparse tensors; got "
f"{__sparse_or_dense(sparse_tensor_1)} and {__sparse_or_dense(sparse_tensor_2)}"
)
if sparse_tensor_1.shape != sparse_tensor_2.shape:
raise ValueError(
"Expected tensors to have same shapes; got "
f"{sparse_tensor_1.shape} and {sparse_tensor_2.shape}"
)
if sparse_tensor_1.sparse_dim() != sparse_tensor_2.sparse_dim():
raise ValueError(
"Expected both sparse tensors to have equal numbers of sparse dims; got "
f"{sparse_tensor_1.sparse_dim()} and {sparse_tensor_2.sparse_dim()}"
)
if sparse_tensor_1.dense_dim() != sparse_tensor_2.dense_dim():
raise ValueError(
"Expected both sparse tensors to have equal numbers of dense dims; got "
f"{sparse_tensor_1.dense_dim()} and {sparse_tensor_2.dense_dim()}"
)
M = sparse_tensor_1.sparse_dim()
K = sparse_tensor_1.dense_dim()
sparse_tensor_1 = sparse_tensor_1.coalesce()
sparse_tensor_2 = sparse_tensor_2.coalesce()
indices_1, values_1 = sparse_tensor_1.indices(), sparse_tensor_1.values()
indices_2, values_2 = sparse_tensor_2.indices(), sparse_tensor_2.values()
# Need to find all indices that are unique to each sparse tensor
# To do this, stack one of them twice and the other once
indices_2_2_1 = torch.cat([indices_2, indices_2, indices_1], -1)
uniques, counts = torch.unique(indices_2_2_1, dim=-1, return_counts=True)
# Any that appear twice in the stacked indices are unique to tensor 2
# and any that appear once are unique to tensor 1
# (indices that appear 3x are shared already)
indices_only_in_tensor_1 = uniques[:, counts == 1]
indices_only_in_tensor_2 = uniques[:, counts == 2]
# Figure out how many new indices will be added to each sparse tensor
n_exclusives_1 = indices_only_in_tensor_1.size(-1)
n_exclusives_2 = indices_only_in_tensor_2.size(-1)
# Make zero-padding for new values tensors
pad_zeros_1 = values_1.new_zeros(
(n_exclusives_2,) + sparse_tensor_1.shape[M : M + K]
)
pad_zeros_2 = values_2.new_zeros(
(n_exclusives_1,) + sparse_tensor_1.shape[M : M + K]
)
# Make the new tensors by stacking indices and values together
tensor_1_unioned = torch.sparse_coo_tensor(
torch.cat([indices_1, indices_only_in_tensor_2], -1),
torch.cat([values_1, pad_zeros_1], 0),
size=sparse_tensor_1.shape,
device=sparse_tensor_1.device,
).coalesce()
tensor_2_unioned = torch.sparse_coo_tensor(
torch.cat([indices_2, indices_only_in_tensor_1], -1),
torch.cat([values_2, pad_zeros_2], 0),
size=sparse_tensor_2.shape,
device=sparse_tensor_2.device,
).coalesce()
if not torch.equal(tensor_1_unioned.indices(), tensor_2_unioned.indices()):
raise RuntimeError("Internal error: unioned tensors have different indices")
return tensor_1_unioned, tensor_2_unioned
Indexing Helpers
flatten_nd_indices(indices, sizes)
Flattens N-dimensional indices into 1-dimensional scalar indices.
Similar to np.ravel_multi_index but with slightly different API.
Examples:
>>> # 2D indices to 1D
>>> indices = torch.tensor([[0, 1, 2], # row indices
... [0, 2, 1]]) # col indices
>>> sizes = torch.tensor([3, 3]) # 3x3 grid
>>> flat, offsets = flatten_nd_indices(indices, sizes)
>>> flat
tensor([[0, 5, 7]]) # 0*3+0=0, 1*3+2=5, 2*3+1=7
>>> offsets
tensor([3, 1])
>>> # 3D indices to 1D
>>> indices = torch.tensor([[0, 1], # dim 0
... [1, 0], # dim 1
... [2, 1]]) # dim 2
>>> sizes = torch.tensor([2, 2, 3]) # 2x2x3 tensor
>>> flat, offsets = flatten_nd_indices(indices, sizes)
>>> flat
tensor([[5, 7]]) # 0*6+1*3+2=5, 1*6+0*3+1=7
>>> offsets
tensor([6, 3, 1])
Source code in pytorch_sparse_utils/indexing/utils.py
@torch.jit.script
def flatten_nd_indices(indices: Tensor, sizes: Tensor) -> tuple[Tensor, Tensor]:
"""Flattens N-dimensional indices into 1-dimensional scalar indices.
Similar to np.ravel_multi_index but with slightly different API.
Args:
indices (Tensor): Integer coordinate tensor of shape [N, B], where N
is the number of dimensions to be flattened and B is the batch dimension.
sizes (Tensor): Extents of every dimension, of shape [N]
Returns:
flat_indices (Tensor): Flattened indices tensor, of shape [1, B]
offsets (Tensor): Strides that were used for flattening (needed for unflatten),
of shape [N]
Examples:
>>> # 2D indices to 1D
>>> indices = torch.tensor([[0, 1, 2], # row indices
... [0, 2, 1]]) # col indices
>>> sizes = torch.tensor([3, 3]) # 3x3 grid
>>> flat, offsets = flatten_nd_indices(indices, sizes)
>>> flat
tensor([[0, 5, 7]]) # 0*3+0=0, 1*3+2=5, 2*3+1=7
>>> offsets
tensor([3, 1])
>>> # 3D indices to 1D
>>> indices = torch.tensor([[0, 1], # dim 0
... [1, 0], # dim 1
... [2, 1]]) # dim 2
>>> sizes = torch.tensor([2, 2, 3]) # 2x2x3 tensor
>>> flat, offsets = flatten_nd_indices(indices, sizes)
>>> flat
tensor([[5, 7]]) # 0*6+1*3+2=5, 1*6+0*3+1=7
>>> offsets
tensor([6, 3, 1])
"""
offsets = _make_linear_offsets(sizes) # [N]
flat_indices = (indices * offsets.unsqueeze(-1)).sum(0, keepdim=True)
return flat_indices, offsets
unflatten_nd_indices(flat_indices, dim_sizes, offsets=None)
Reconstructs ('unflattens') N-D indices from 1D 'flattened' indices.
Examples:
>>> # Unflatten 1D to 2D indices (inverse of flatten)
>>> flat_indices = torch.tensor([[0, 5, 7]])
>>> dim_sizes = torch.tensor([3, 3]) # 3x3 grid
>>> nd_indices = unflatten_nd_indices(flat_indices, dim_sizes)
>>> nd_indices
tensor([[0, 1, 2], # row indices
[0, 2, 1]]) # col indices
>>> # Unflatten to 3D indices
>>> flat_indices = torch.tensor([[5, 7, 23]])
>>> dim_sizes = torch.tensor([2, 3, 4]) # 2x3x4 tensor
>>> nd_indices = unflatten_nd_indices(flat_indices, dim_sizes)
>>> nd_indices
tensor([[0, 0, 1], # dim 0
[1, 1, 2], # dim 1
[1, 3, 3]]) # dim 2
>>> # Using precomputed offsets (more efficient for repeated calls)
>>> dim_sizes = torch.tensor([4, 5, 6])
>>> offsets = _make_linear_offsets(dim_sizes) # [30, 6, 1]
>>> flat_indices = torch.tensor([[0, 31, 119]])
>>> nd_indices = unflatten_nd_indices(flat_indices, dim_sizes, offsets)
>>> nd_indices
tensor([[0, 1, 3],
[0, 0, 4],
[0, 1, 5]])
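Since flatten_nd_indices returns the offsets it used, the two helpers form a round trip; a quick sketch (assuming both are imported from pytorch_sparse_utils.indexing.utils):
>>> indices = torch.tensor([[0, 1, 2],
...                         [0, 2, 1]])
>>> sizes = torch.tensor([3, 3])
>>> flat, offsets = flatten_nd_indices(indices, sizes)
>>> recovered = unflatten_nd_indices(flat, sizes, offsets)
>>> torch.equal(recovered, indices)
True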
Source code in pytorch_sparse_utils/indexing/utils.py
@torch.jit.script
def unflatten_nd_indices(
flat_indices: Tensor, dim_sizes: Tensor, offsets: Optional[Tensor] = None
) -> Tensor:
"""Reconstructs ('unflattens') N-D indices from 1D 'flattened' indices.
Args:
flat_indices (Tensor): Flat indices tensor of shape [1, B]
dim_sizes (Tensor): Original sizes of every dimension, of shape [N]
offsets (Optional[Tensor]): Offsets that were used for flattening, as returned
by _make_linear_offsets or flatten_nd_indices. If None, it will be
recalculated from `dim_sizes`
Returns:
Tensor: N-D indices tensor of shape [N, B]
Examples:
>>> # Unflatten 1D to 2D indices (inverse of flatten)
>>> flat_indices = torch.tensor([[0, 5, 7]])
>>> dim_sizes = torch.tensor([3, 3]) # 3x3 grid
>>> nd_indices = unflatten_nd_indices(flat_indices, dim_sizes)
>>> nd_indices
tensor([[0, 1, 2], # row indices
[0, 2, 1]]) # col indices
>>> # Unflatten to 3D indices
>>> flat_indices = torch.tensor([[5, 7, 23]])
>>> dim_sizes = torch.tensor([2, 3, 4]) # 2x3x4 tensor
>>> nd_indices = unflatten_nd_indices(flat_indices, dim_sizes)
>>> nd_indices
tensor([[0, 0, 1], # dim 0
[1, 1, 2], # dim 1
[1, 3, 3]]) # dim 2
>>> # Using precomputed offsets (more efficient for repeated calls)
>>> dim_sizes = torch.tensor([4, 5, 6])
>>> offsets = _make_linear_offsets(dim_sizes) # [30, 6, 1]
>>> flat_indices = torch.tensor([[0, 31, 119]])
>>> nd_indices = unflatten_nd_indices(flat_indices, dim_sizes, offsets)
>>> nd_indices
tensor([[0, 1, 3],
[0, 0, 4],
[0, 1, 5]])
"""
if offsets is None:
offsets = _make_linear_offsets(dim_sizes)
assert offsets is not None
N = dim_sizes.numel()
B = flat_indices.size(-1)
out = torch.empty(N, B, device=flat_indices.device, dtype=torch.long)
# integer divide by stride
torch.div(
flat_indices.expand_as(out),
offsets.unsqueeze(1),
rounding_mode="floor",
out=out,
)
# modulus by sizes
torch.remainder(out, dim_sizes.unsqueeze(1), out=out)
return out
flatten_sparse_indices(tensor, start_axis, end_axis)
Flattens a sparse tensor's indices along specified dimensions.
This function takes a sparse tensor and flattens its indices along a contiguous range of dimensions. It returns the new indices, the corresponding new shape, and the linear offsets used in the flattening process.
Examples:
>>> # Create a 3D sparse tensor
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 0], [1, 2, 0]])
>>> values = torch.tensor([1.0, 2.0, 3.0])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3, 3))
>>> # Flatten dimensions 0 and 1 (first two dimensions)
>>> new_indices, new_shape, offsets = flatten_sparse_indices(sparse, 0, 1)
>>> new_indices
tensor([[0, 4, 6], # Flattened indices for dims 0-1
[1, 2, 0]]) # Unchanged indices for dim 2
>>> new_shape
tensor([9, 3]) # 3x3 flattened to 9, last dim unchanged
>>> offsets
tensor([3, 1])
>>> # Flatten all dimensions
>>> new_indices, new_shape, offsets = flatten_sparse_indices(sparse, 0, 2)
>>> new_indices
tensor([[1, 14, 18]]) # All indices flattened to 1D
>>> new_shape
tensor([27]) # 3x3x3 = 27
>>> # Create a 4D sparse tensor and flatten middle dimensions
>>> indices = torch.tensor([[0, 1], [1, 0], [2, 1], [0, 2]])
>>> values = torch.tensor([1.0, 2.0])
>>> sparse = torch.sparse_coo_tensor(indices, values, (2, 2, 3, 3))
>>>
>>> # Flatten dimensions 1 and 2
>>> new_indices, new_shape, offsets = flatten_sparse_indices(sparse, 1, 2)
>>> new_indices
tensor([[0, 1], # Dim 0 unchanged
[5, 1], # Dims 1-2 flattened: 1*3+2=5, 0*3+1=1
[0, 2]]) # Dim 3 unchanged
>>> new_shape
tensor([2, 6, 3]) # Shape is now [2, 2*3, 3]
Source code in pytorch_sparse_utils/indexing/utils.py
@torch.jit.script
def flatten_sparse_indices(
tensor: Tensor, start_axis: int, end_axis: int
) -> tuple[Tensor, Tensor, Tensor]:
"""Flattens a sparse tensor's indices along specified dimensions.
This function takes a sparse tensor and flattens its indices along a
contiguous range of dimensions. It returns the new indices, the
corresponding new shape, and the linear offsets used in the flattening
process.
Args:
tensor (Tensor): The input tensor. Its indices are expected to be in COO format.
start_axis (int): Starting axis (inclusive) of the dimensions to flatten.
end_axis (int): Ending axis (inclusive) of the dimensions to flatten.
Returns:
- new_indices (Tensor): The flattened indices of shape (D, N), where D is
the number of dimensions in the flattened tensor and N is the number
of nonzero elements.
- new_shape (Tensor): The new shape of the flattened tensor of shape (D,)
- dim_linear_offsets (Tensor): The linear offsets used during flattening,
of shape (K,), where K is the number of flattened dimensions.
Examples:
>>> # Create a 3D sparse tensor
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 0], [1, 2, 0]])
>>> values = torch.tensor([1.0, 2.0, 3.0])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3, 3))
>>> # Flatten dimensions 0 and 1 (first two dimensions)
>>> new_indices, new_shape, offsets = flatten_sparse_indices(sparse, 0, 1)
>>> new_indices
tensor([[0, 4, 6], # Flattened indices for dims 0-1
[1, 2, 0]]) # Unchanged indices for dim 2
>>> new_shape
tensor([9, 3]) # 3x3 flattened to 9, last dim unchanged
>>> offsets
tensor([3, 1])
>>> # Flatten all dimensions
>>> new_indices, new_shape, offsets = flatten_sparse_indices(sparse, 0, 2)
>>> new_indices
tensor([[1, 14, 18]]) # All indices flattened to 1D
>>> new_shape
tensor([27]) # 3x3x3 = 27
>>> # Create a 4D sparse tensor and flatten middle dimensions
>>> indices = torch.tensor([[0, 1], [1, 0], [2, 1], [0, 2]])
>>> values = torch.tensor([1.0, 2.0])
>>> sparse = torch.sparse_coo_tensor(indices, values, (2, 2, 3, 3))
>>>
>>> # Flatten dimensions 1 and 2
>>> new_indices, new_shape, offsets = flatten_sparse_indices(sparse, 1, 2)
>>> new_indices
tensor([[0, 1], # Dim 0 unchanged
[5, 1], # Dims 1-2 flattened: 1*3+2=5, 0*3+1=1
[0, 2]]) # Dim 3 unchanged
>>> new_shape
tensor([2, 6, 3]) # Shape is now [2, 2*3, 3]
"""
tensor_indices = tensor.indices() # sparse_dim x nnz (counterintuitive)
indices_to_flatten = tensor_indices[start_axis : end_axis + 1]
# convert shape to tensor since we will be doing math on it.
# it needs to be on the same device as the sparse tensor rather than
# staying on cpu because downstream tensors will be interacting with
# the sparse tensor's indices tensor
shape = torch._shape_as_tensor(tensor).to(tensor.device)
sizes_to_flatten = shape[start_axis : end_axis + 1]
flattened_indices, dim_linear_offsets = flatten_nd_indices(
indices_to_flatten, sizes_to_flatten
)
# make new shape with the flattened axes stacked together
new_shape = torch.cat(
[
shape[:start_axis],
torch.prod(sizes_to_flatten, 0, keepdim=True),
shape[end_axis + 1 :],
]
)
# this assertion shouldn't cause a cpu sync
assert new_shape.size(0) == tensor.ndim - (end_axis - start_axis)
# plug the flattened indices into the existing indices
new_indices = torch.cat(
[tensor_indices[:start_axis], flattened_indices, tensor_indices[end_axis + 1 :]]
)
return new_indices, new_shape, dim_linear_offsets
linearize_sparse_and_index_tensors(sparse_tensor, index_tensor)
Converts both a sparse tensor's multidimensional indices and a tensor of lookup indices to a shared linearized (flattened) format suitable for fast lookup.
Examples:
>>> # Create a 2D sparse tensor
>>> indices = torch.tensor([[0, 1, 2], [0, 2, 1]])
>>> values = torch.tensor([1.0, 2.0, 3.0])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3))
>>> # Indices to look up
>>> lookup_indices = torch.tensor([[0, 0], [1, 2], [2, 1]])
>>>
>>> sparse_lin, lookup_lin = linearize_sparse_and_index_tensors(sparse, lookup_indices)
>>> sparse_lin
tensor([0, 5, 7]) # 0*3+0=0, 1*3+2=5, 2*3+1=7
>>> lookup_lin
tensor([0, 5, 7]) # Same linearization scheme
>>> # Batch of indices with different shape
>>> batch_indices = torch.tensor([[[0, 0], [1, 1]],
... [[2, 2], [0, 2]]])
>>> sparse_lin, batch_lin = linearize_sparse_and_index_tensors(sparse, batch_indices)
>>> batch_lin
tensor([0, 4, 8, 2]) # Flattened: 0, 1*3+1=4, 2*3+2=8, 0*3+2=2
>>> # 3D sparse tensor
>>> indices = torch.tensor([[0, 1], [0, 1], [1, 0]])
>>> values = torch.tensor([1.0, 2.0])
>>> sparse = torch.sparse_coo_tensor(indices, values, (2, 2, 2))
>>>
>>> lookup = torch.tensor([[0, 0, 1], [1, 1, 0]])
>>> sparse_lin, lookup_lin = linearize_sparse_and_index_tensors(sparse, lookup)
>>> sparse_lin
tensor([1, 6]) # 0*4+0*2+1=1, 1*4+1*2+0=6
>>> lookup_lin
tensor([1, 6])
Source code in pytorch_sparse_utils/indexing/utils.py
@torch.jit.script
def linearize_sparse_and_index_tensors(
sparse_tensor: Tensor, index_tensor: Tensor
) -> tuple[Tensor, Tensor]:
"""Converts multidimensional indices of a sparse tensor and a tensor of indices
that we want to retrieve to a shared linearized (flattened) format suitable
for fast lookup.
Args:
sparse_tensor (Tensor): torch.sparse_coo_tensor with indices to linearize.
index_tensor (Tensor): Dense tensor with indices matching sparse_tensor's
sparse dims. Can be of any dimension as long as the last dimension
has length equal to the sparse tensor's sparse dimension.
Raises:
ValueError: If the index tensor has a different last dimension than the
sparse tensor's sparse dim.
Returns:
sparse_tensor_indices_linear (Tensor): Linearized version of
sparse_tensor.indices().
index_tensor_linearized (Tensor): Linearized version of index_tensor
with the last dimension squeezed out.
Examples:
>>> # Create a 2D sparse tensor
>>> indices = torch.tensor([[0, 1, 2], [0, 2, 1]])
>>> values = torch.tensor([1.0, 2.0, 3.0])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3))
>>> # Indices to look up
>>> lookup_indices = torch.tensor([[0, 0], [1, 2], [2, 1]])
>>>
>>> sparse_lin, lookup_lin = linearize_sparse_and_index_tensors(sparse, lookup_indices)
>>> sparse_lin
tensor([0, 5, 7]) # 0*3+0=0, 1*3+2=5, 2*3+1=7
>>> lookup_lin
tensor([0, 5, 7]) # Same linearization scheme
>>> # Batch of indices with different shape
>>> batch_indices = torch.tensor([[[0, 0], [1, 1]],
... [[2, 2], [0, 2]]])
>>> sparse_lin, batch_lin = linearize_sparse_and_index_tensors(sparse, batch_indices)
>>> batch_lin
tensor([0, 4, 8, 2]) # Flattened: 0, 1*3+1=4, 2*3+2=8, 0*3+2=2
>>> # 3D sparse tensor
>>> indices = torch.tensor([[0, 1], [0, 1], [1, 0]])
>>> values = torch.tensor([1.0, 2.0])
>>> sparse = torch.sparse_coo_tensor(indices, values, (2, 2, 2))
>>>
>>> lookup = torch.tensor([[0, 0, 1], [1, 1, 0]])
>>> sparse_lin, lookup_lin = linearize_sparse_and_index_tensors(sparse, lookup)
>>> sparse_lin
tensor([1, 6]) # 0*4+0*2+1=1, 1*4+1*2+0=6
>>> lookup_lin
tensor([1, 6])
"""
if index_tensor.shape[-1] != sparse_tensor.sparse_dim():
if (
sparse_tensor.sparse_dim() - 1 == index_tensor.shape[-1]
and sparse_tensor.shape[-1] == 1
and sparse_tensor.dense_dim() == 0
):
# handle case where there's a length-1 trailing sparse dim and the
# index tensor ignores it
sparse_tensor = sparse_tensor[..., 0].coalesce()
else:
raise ValueError(
"Expected last dim of `index_tensor` to be the same as "
"`sparse_tensor.sparse_dim()`, got "
f"{str(index_tensor.shape[-1])} and {sparse_tensor.sparse_dim()}, "
"respectively."
)
sparse_tensor_indices_linear, _, dim_linear_offsets = flatten_sparse_indices(
sparse_tensor, 0, sparse_tensor.sparse_dim() - 1
)
sparse_tensor_indices_linear.squeeze_(0)
# repeat the index flattening for the index tensor. The sparse tensor's indices
# were already flattened in __flattened_indices
index_tensor_linearized = (index_tensor * dim_linear_offsets).sum(-1).view(-1)
return (
sparse_tensor_indices_linear,
index_tensor_linearized,
)
get_sparse_index_mapping(sparse_tensor, index_tensor, sanitize_linear_index_tensor=True)
Finds the locations along a sparse tensor's values tensor for specified sparse indices. Also returns a mask indicating which indices have values actually present in the sparse tensor. It works by flattening the sparse tensor's sparse dims and the index tensor to 1D (and converting n-d indices to raveled indices), then using searchsorted along the flattened sparse tensor indices.
Examples:
>>> # Create a sparse tensor
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 0]])
>>> values = torch.tensor([10.0, 20.0, 30.0])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3))
>>> # Look up existing indices
>>> lookup = torch.tensor([[0, 0], [1, 1], [2, 0]])
>>> positions, mask = get_sparse_index_mapping(sparse, lookup)
>>> positions
tensor([0, 1, 2]) # Positions in values tensor
>>> mask
tensor([True, True, True]) # All indices exist
>>> sparse.values()[positions]
tensor([10., 20., 30.])
>>> # Look up mix of existing and non-existing indices
>>> lookup = torch.tensor([[0, 0], [0, 1], [1, 0]])
>>> positions, mask = get_sparse_index_mapping(sparse, lookup)
>>> positions
tensor([0, 0, 0]) # Non-existing indices mapped to 0 (sanitized)
>>> mask
tensor([True, False, False]) # Only first index exists
>>> # With sanitize=False to get insertion positions
>>> positions, mask = get_sparse_index_mapping(sparse, lookup,
... sanitize_linear_index_tensor=False)
>>> positions
tensor([0, 1, 1]) # Position 1 is where [0,1] and [1,0] would be inserted
>>> mask
tensor([True, False, False])
>>> # Batch lookup
>>> batch_lookup = torch.tensor([[[0, 0], [2, 0]],
... [[1, 1], [0, 2]]])
>>> positions, mask = get_sparse_index_mapping(sparse, batch_lookup)
>>> positions
tensor([[0, 2],
[1, 0]])
>>> mask
tensor([[True, True],
[True, False]])
>>> # Out of bounds indices
>>> lookup = torch.tensor([[0, 0], [3, 0], [-1, 0]]) # 3 and -1 are out of bounds
>>> positions, mask = get_sparse_index_mapping(sparse, lookup)
>>> mask
tensor([True, False, False]) # Out of bounds treated as unspecified
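These helpers compose into a manual gather, which is essentially what batch_sparse_index does internally (see its source above); a sketch:
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 0]])
>>> values = torch.tensor([10.0, 20.0, 30.0])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3)).coalesce()
>>> lookup = torch.tensor([[0, 0], [0, 1], [2, 0]])  # middle index is unspecified
>>> positions, mask = get_sparse_index_mapping(sparse, lookup)
>>> gather_mask_and_fill(sparse.values(), positions, mask)
tensor([10.,  0., 30.])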
Source code in pytorch_sparse_utils/indexing/utils.py
@torch.jit.script
def get_sparse_index_mapping(
sparse_tensor: Tensor,
index_tensor: Tensor,
sanitize_linear_index_tensor: bool = True,
) -> tuple[Tensor, Tensor]:
"""Finds the locations along a sparse tensor's values tensor for specified
sparse indices. Also returns a mask indicating which indices have values
actually present in the sparse tensor. It works by flattening the sparse
tensor's sparse dims and the index tensor to 1D (and converting n-d indices
to raveled indices), then using searchsorted along the flattened sparse
tensor indices.
Args:
sparse_tensor (Tensor): Sparse tensor of dimension ..., M; where ... are
S leading sparse dimensions and M is the dense dimension.
index_tensor (Tensor): Long tensor of dimension ..., S; where ... are
leading batch dimensions. Negative indices and indices outside the
bounds of the sparse dimensions are not supported and will
be considered unspecified, with the corresponding entry in
is_specified_mask being set to False.
sanitize_linear_index_tensor (bool): If False, then the output values at
linear_index_tensor[~is_specified_mask] will be the "insertion position"
that would keep the sparse tensor's indices ordered. This is useful if you
want to insert values, but means that
sparse_tensor.values()[linear_index_tensor] will be potentially unsafe if
some of the "insertion position" values are out of bounds. If this arg is
True, linear_index_tensor[~is_specified_mask] values will be set to 0.
Defaults to True.
Returns:
linear_index_tensor: Long tensor of dimension ... of the locations in
sparse_tensor.values() corresponding to the indices in index_tensor.
Elements where is_specified_mask is False are handled according to the
value of sanitize_linear_index_tensor.
is_specified_mask: Boolean tensor of dimension ... that is True for
indices in index_tensor where values were actually specified in
the sparse tensor and False for indices that were unspecified in
the sparse tensor.
Examples:
>>> # Create a sparse tensor
>>> indices = torch.tensor([[0, 1, 2], [0, 1, 0]])
>>> values = torch.tensor([10.0, 20.0, 30.0])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3))
>>> # Look up existing indices
>>> lookup = torch.tensor([[0, 0], [1, 1], [2, 0]])
>>> positions, mask = get_sparse_index_mapping(sparse, lookup)
>>> positions
tensor([0, 1, 2]) # Positions in values tensor
>>> mask
tensor([True, True, True]) # All indices exist
>>> sparse.values()[positions]
tensor([10., 20., 30.])
>>> # Look up mix of existing and non-existing indices
>>> lookup = torch.tensor([[0, 0], [0, 1], [1, 0]])
>>> positions, mask = get_sparse_index_mapping(sparse, lookup)
>>> positions
tensor([0, 0, 0]) # Non-existing indices mapped to 0 (sanitized)
>>> mask
tensor([True, False, False]) # Only first index exists
>>> # With sanitize=False to get insertion positions
>>> positions, mask = get_sparse_index_mapping(sparse, lookup,
... sanitize_linear_index_tensor=False)
>>> positions
tensor([0, 1, 1]) # Position 1 is where [0,1] and [1,0] would be inserted
>>> mask
tensor([True, False, False])
>>> # Batch lookup
>>> batch_lookup = torch.tensor([[[0, 0], [2, 0]],
... [[1, 1], [0, 2]]])
>>> positions, mask = get_sparse_index_mapping(sparse, batch_lookup)
>>> positions
tensor([[0, 2],
[1, 0]])
>>> mask
tensor([[True, True],
[True, False]])
>>> # Out of bounds indices
>>> lookup = torch.tensor([[0, 0], [3, 0], [-1, 0]]) # 3 and -1 are out of bounds
>>> positions, mask = get_sparse_index_mapping(sparse, lookup)
>>> mask
tensor([True, False, False]) # Out of bounds treated as unspecified
"""
sparse_dim = sparse_tensor.sparse_dim()
sparse_nnz = sparse_tensor._nnz()
sparse_tensor_shape = torch._shape_as_tensor(sparse_tensor).to(
device=index_tensor.device
)
sparse_shape = sparse_tensor_shape[:sparse_dim]
# check for empty sparse tensor
if sparse_nnz == 0:
linear_index_tensor = index_tensor.new_zeros(index_tensor.shape[:-1])
is_specified_mask = index_tensor.new_zeros(
index_tensor.shape[:-1], dtype=torch.bool
)
return linear_index_tensor, is_specified_mask
# Check for out of bounds indices (below 0 or outside tensor dim)
out_of_bounds_indices = torch.any(index_tensor < 0, -1)
out_of_bounds_indices |= torch.any(index_tensor >= sparse_shape, -1)
# put dummy value of 0 in the OOB indices.
# Maybe it'll make the linearization computations and searchsorted faster:
# a compromise between just giving searchsorted random indices to find vs
# causing a cpu sync to call nonzeros to filter them out
index_tensor = index_tensor.masked_fill(out_of_bounds_indices.unsqueeze(-1), 0)
(
sparse_tensor_indices_linearized,
index_tensor_linearized,
) = linearize_sparse_and_index_tensors(sparse_tensor, index_tensor)
# The dummy value of 0 should always return searched index of 0 since
# the sparse_tensor_indices_linearized values are always nonnegative.
# Should be faster to find than random search values.
linear_index_tensor = torch.searchsorted( # binary search
sparse_tensor_indices_linearized, index_tensor_linearized
)
# linear_index_tensor is distinct from index_tensor_linearized in that
# index_tensor_linearized has the flattened version of the index in the sparse
# tensor, while linear_index_tensor has the corresponding index in the sparse
# tensor's values() tensor
# guard against IndexError
if sanitize_linear_index_tensor:
index_clamped = linear_index_tensor.clamp_max_(sparse_nnz - 1)
else:
index_clamped = linear_index_tensor.clamp_max(sparse_nnz - 1)
# Check if the indices were specified by checking for an exact match at the
# resultant searched indices
is_specified_mask: Tensor = (
sparse_tensor_indices_linearized[index_clamped] == index_tensor_linearized
)
is_specified_mask &= ~out_of_bounds_indices.view(-1)
linear_index_tensor = linear_index_tensor.view(index_tensor.shape[:-1])
is_specified_mask = is_specified_mask.view(index_tensor.shape[:-1])
return linear_index_tensor, is_specified_mask
gather_mask_and_fill(values, indices, mask, fill=None)
Efficiently gathers elements from an ND tensor, applies a mask, and fills masked positions.
This function performs the equivalent of
out = values[indices]
out[~mask] = fill.expand_as(out)[~mask] # or 0
but uses torch.index_select for better performance. It retrieves values at the
specified indices and fills positions where the mask is False with either zeros
(default) or the provided fill values.
Examples:
>>> # Basic usage: gather and mask 1D values
>>> values = torch.tensor([10.0, 20.0, 30.0, 40.0])
>>> indices = torch.tensor([0, 2, 1, 3])
>>> mask = torch.tensor([True, True, False, True])
>>> gather_mask_and_fill(values, indices, mask)
tensor([10., 30., 0., 40.])
>>> # Multi-dimensional values
>>> values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> indices = torch.tensor([2, 0, 1])
>>> mask = torch.tensor([True, False, True])
>>> gather_mask_and_fill(values, indices, mask)
tensor([[5., 6.],
[0., 0.],
[3., 4.]])
>>> # 2D indices with custom fill
>>> values = torch.tensor([10.0, 20.0, 30.0])
>>> indices = torch.tensor([[0, 1], [2, 0]])
>>> mask = torch.tensor([[True, False], [True, True]])
>>> gather_mask_and_fill(values, indices, mask, fill=torch.tensor(-1.0))
tensor([[10., -1.],
[30., 10.]])
Source code in pytorch_sparse_utils/indexing/utils.py
@torch.jit.script
def gather_mask_and_fill(
values: Tensor, indices: Tensor, mask: Tensor, fill: Optional[Tensor] = None
) -> Tensor:
"""Efficiently gathers elements from an ND tensor, applies a mask, and fills masked
positions.
This function performs the equivalent of
`out = values[indices]
out[~mask] = fill.expand_as(out)[~mask] # or 0
`
but uses torch.index_select for better performance. It retrieves values at the
specified indices and fills positions where the mask is False with either zeros
(default) or the provided fill values.
Args:
values (Tensor): Source tensor to gather from, may be 1D with shape (N)
or n-D with shape (N, D0, D1, ...), where N is the number of elements
and D are potentially multiple feature dimensions.
indices (Tensor): Long tensor of indices into the first dimension of values.
Can be of any shape.
mask (Tensor): Boolean tensor with the same shape as indices. True indicates
positions to keep, False indicates positions to zero out.
fill (Optional[Tensor]): A tensor that must be broadcast-compatible with the
final output shape. It is inserted at positions where `mask` is False.
When None (default), a zero tensor is used.
Returns:
Tensor: The gathered and masked values with shape
(*indices.shape, *values.shape[1:]). Contains values from the source tensor
at the specified indices, with masked positions filled with zeros or from
`fill`.
Raises:
ValueError: If indices and mask have different shapes.
Examples:
>>> # Basic usage: gather and mask 1D values
>>> values = torch.tensor([10.0, 20.0, 30.0, 40.0])
>>> indices = torch.tensor([0, 2, 1, 3])
>>> mask = torch.tensor([True, True, False, True])
>>> gather_mask_and_fill(values, indices, mask)
tensor([10., 30., 0., 40.])
>>> # Multi-dimensional values
>>> values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> indices = torch.tensor([2, 0, 1])
>>> mask = torch.tensor([True, False, True])
>>> gather_mask_and_fill(values, indices, mask)
tensor([[5., 6.],
[0., 0.],
[3., 4.]])
>>> # 2D indices with custom fill
>>> values = torch.tensor([10.0, 20.0, 30.0])
>>> indices = torch.tensor([[0, 1], [2, 0]])
>>> mask = torch.tensor([[True, False], [True, True]])
>>> gather_mask_and_fill(values, indices, mask, fill=torch.tensor(-1.0))
tensor([[10., -1.],
[30., 10.]])
"""
input_values_1d = False
if values.ndim == 1:
input_values_1d = True
values = values.unsqueeze(1)
if indices.shape != mask.shape:
raise ValueError(
"Expected indices and mask to have same shape, got "
f"{indices.shape} and {mask.shape}"
)
indices_flat = indices.reshape(-1)
mask_flat = mask.reshape(-1)
# figure out how much to broadcast
value_dims = values.shape[1:]
n_value_dims = values.ndim - 1
new_shape = indices.shape
if not input_values_1d:
new_shape += value_dims
# pre-mask the indices to guard against unsafe indices in the masked portion
indices_flat = torch.where(mask_flat, indices_flat, torch.zeros_like(indices_flat))
selected = values.index_select(0, indices_flat)
# unsqueeze mask
mask_flat = mask_flat.view((mask_flat.size(0),) + (1,) * n_value_dims)
if fill is None:
selected.masked_fill_(~mask_flat, 0)
else:
fill_broadcast = fill.expand(new_shape).reshape(selected.shape)
selected = torch.where(mask_flat, selected, fill_broadcast)
selected = selected.reshape(new_shape)
return selected