Third-party sparse library integration
pytorch-sparse-utils provides integrations with three major libraries for sparse arrays and tensors:
- Pydata sparse, a numpy-like sparse array implementation with close numpy integration.
- MinkowskiEngine, an Nvidia library for convolutions on sparse tensors.
- spconv, another library for convolutions on sparse tensors.
All three libraries use their own sparse tensor formats, which are distinct from PyTorch's built-in sparse tensors. The `pytorch_sparse_utils.conversion` module provides simple utilities to convert between these formats.
These conversion utilities make it possible to build, for example, a pipeline in which:

- images are loaded as Pydata sparse COO arrays;
- they are converted to PyTorch sparse tensors in a `torch.utils.data.DataLoader`;
- they are converted to MinkowskiEngine `SparseTensor`s for processing by a MinkowskiEngine CNN backbone;
- they are converted back to PyTorch sparse tensors for processing by a Transformer module.
Pydata sparse conversions
torch_sparse_to_pydata_sparse(tensor)
Converts a sparse torch.Tensor to an equivalent Pydata sparse COO array
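Parameters:

| Name | Type | Description |
|---|---|---|
| `tensor` | `torch.Tensor` | Sparse tensor to be converted |

Returns:

| Type | Description |
|---|---|
| `sparse.COO` | Pydata sparse COO array |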
Source code in pytorch_sparse_utils/conversion.py
def torch_sparse_to_pydata_sparse(tensor: Tensor) -> sparse.COO:
    """Converts a sparse torch.Tensor to an equivalent Pydata sparse COO array.

    Explicitly stored zero values are dropped from the result. Hybrid sparse
    tensors (whose ``values()`` carry trailing dense dimensions) are fully
    expanded: every nonzero element of the dense part becomes its own COO
    entry, so the returned array always has the same shape as ``tensor``.

    The tensor is detached and moved to the CPU before conversion, so the
    result does not share autograd history with ``tensor``.

    Args:
        tensor (torch.Tensor): Sparse COO tensor to be converted

    Returns:
        array (sparse.COO): Pydata sparse COO array of shape ``tensor.shape``
    """
    assert tensor.is_sparse
    # coalesce() merges duplicate indices (summing their values), which is
    # what allows has_duplicates=False below.
    tensor = tensor.detach().cpu().coalesce()
    values = tensor.values()
    # values has shape [nnz, *dense_shape]; nonzero(as_tuple=True) returns one
    # index tensor per dimension of values.
    nonzero_idx = values.nonzero(as_tuple=True)
    coords = tensor.indices()[:, nonzero_idx[0]]
    if values.ndim > 1:
        # Hybrid tensor: append the coordinates along the dense dimensions so
        # coords has one row per dimension of tensor.shape, as sparse.COO
        # requires.
        dense_coords = torch.stack(nonzero_idx[1:], dim=0)
        coords = torch.cat([coords, dense_coords], dim=0)
    return sparse.COO(
        coords.numpy(),
        values[nonzero_idx].numpy(),
        tensor.shape,
        has_duplicates=False,
    )
pydata_sparse_to_torch_sparse(sparse_array, device=None)
Converts a Pydata sparse COO array to an equivalent sparse torch.Tensor
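Parameters:

| Name | Type | Description |
|---|---|---|
| `sparse_array` | `sparse.COO` | Pydata sparse COO array to be converted |
| `device` | `Optional[Union[str, torch.device]]` | Device on which to create the sparse tensor. Defaults to None (default device). |

Returns:

| Type | Description |
|---|---|
| `torch.Tensor` | Converted sparse tensor |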
Source code in pytorch_sparse_utils/conversion.py
def pydata_sparse_to_torch_sparse(
    sparse_array: sparse.COO, device: Optional[Union[str, torch.device]] = None
) -> Tensor:
    """Converts a Pydata sparse COO array to an equivalent sparse torch.Tensor

    Args:
        sparse_array (sparse.COO): Pydata sparse COO array to be converted
        device (Optional[Union[str, torch.device]]): Device on which to create the
            sparse tensor. Defaults to None (default device).

    Returns:
        tensor (torch.Tensor): Converted, coalesced sparse tensor

    Raises:
        ValueError: If ``sparse_array`` has a nonzero ``fill_value``. Torch
            sparse tensors always treat unspecified elements as zero, so such
            an array cannot be represented faithfully.
    """
    # Without this check, every implicit element would silently become 0
    # instead of the array's fill value.
    if sparse_array.fill_value != 0:
        raise ValueError(
            "Only Pydata sparse arrays with a fill value of 0 can be converted "
            f"to torch sparse tensors, got fill_value={sparse_array.fill_value}"
        )
    return torch.sparse_coo_tensor(
        indices=sparse_array.coords,  # pyright: ignore[reportArgumentType]
        values=sparse_array.data,  # pyright: ignore[reportArgumentType]
        size=sparse_array.shape,
        device=device,
    ).coalesce()
MinkowskiEngine conversions
torch_sparse_to_minkowski(tensor)
Converts a sparse torch.Tensor to an equivalent MinkowskiEngine SparseTensor
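Parameters:

| Name | Type | Description |
|---|---|---|
| `tensor` | `torch.Tensor` | Sparse tensor to be converted |

Returns:

| Type | Description |
|---|---|
| `MinkowskiEngine.SparseTensor` | Converted sparse tensor |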
Source code in pytorch_sparse_utils/conversion.py
@imports.requires_minkowskiengine
def torch_sparse_to_minkowski(tensor: Tensor):
    """Converts a sparse torch.Tensor to an equivalent MinkowskiEngine SparseTensor

    The tensor is coalesced first, so uncoalesced inputs are accepted and
    duplicate indices are summed (PyTorch's sparse semantics) rather than
    being passed to MinkowskiEngine as repeated coordinates. Each column of
    the coalesced ``indices()`` becomes one coordinate row; scalar (1D)
    values are reshaped to a single feature channel.

    Args:
        tensor (torch.Tensor): Sparse tensor to be converted

    Returns:
        sparse_tensor (MinkowskiEngine.SparseTensor): Converted sparse tensor
    """
    assert isinstance(tensor, Tensor)
    assert tensor.is_sparse
    # indices()/values() raise RuntimeError on an uncoalesced sparse tensor.
    coalesced = tensor.coalesce()
    features = coalesced.values()
    if features.ndim == 1:
        # Scalar features -> [nnz, 1] feature matrix
        features = features.unsqueeze(-1)
    # [sparse_dim, nnz] -> [nnz, sparse_dim] int32 coordinates
    coordinates = coalesced.indices().T.int().contiguous()
    return ME.SparseTensor(
        features, coordinates, requires_grad=tensor.requires_grad, device=tensor.device
    )
minkowski_to_torch_sparse(tensor, full_scale_spatial_shape=None, squeeze=False)
Converts a MinkowskiEngine SparseTensor to an equivalent sparse torch.Tensor
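Parameters:

| Name | Type | Description |
|---|---|---|
| `tensor` | `MinkowskiEngine.SparseTensor` | Sparse tensor to be converted |
| `full_scale_spatial_shape` | `Optional[Union[list[int], Tensor]]` | Full extent of the spatial domain on which the sparse data reside. If given, it defines the size of the sparse tensor; otherwise the size is inferred from the tensor's indices. Default: None |
| `squeeze` | `bool` | If True and the feature dimension is 1, the returned tensor's values are squeezed to shape `[nnz]` instead of `[nnz, 1]`. Raises an error if True and the feature dimension is not 1. |

Returns:

| Type | Description |
|---|---|
| `torch.Tensor` | Converted sparse tensor |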
Source code in pytorch_sparse_utils/conversion.py
@imports.requires_minkowskiengine
def minkowski_to_torch_sparse(
    tensor: Union[Tensor, ME.SparseTensor],
    full_scale_spatial_shape: Optional[
        Union[Tensor, list[int], tuple[int, ...]]
    ] = None,
    squeeze: bool = False,
) -> Tensor:
    """Converts a MinkowskiEngine SparseTensor to an equivalent sparse torch.Tensor

    If ``tensor`` is already a sparse torch.Tensor it is returned unchanged
    (``full_scale_spatial_shape`` and ``squeeze`` are then ignored).

    Args:
        tensor (MinkowskiEngine.SparseTensor): Sparse tensor to be converted
        full_scale_spatial_shape (Optional[Union[Tensor, list[int], tuple[int, ...]]]):
            The full extent of the spatial domain on which the sparse data
            reside. Any sequence of ints is accepted (e.g. a list, tuple, or
            ``torch.Size``), as is a Tensor. If given, it is used to define the
            size of the sparse tensor. If not given, the size is inferred from
            the indices in the tensor. Default: None
        squeeze (bool): If True and the feature dimension of the MinkowskiEngine
            SparseTensor is 1, the returned sparse torch.Tensor will have its values
            squeezed to 1D shape of [nnz] rather than [nnz, 1]. Raises an error if
            True and the feature dimension is not 1.

    Returns:
        tensor (torch.Tensor): Converted, coalesced sparse tensor

    Raises:
        ValueError: If ``squeeze`` is True and the feature dimension is not 1.
    """
    if isinstance(tensor, Tensor):
        assert tensor.is_sparse
        return tensor
    assert isinstance(tensor, ME.SparseTensor)
    min_coords = torch.zeros([tensor.dimension], dtype=torch.int, device=tensor.device)
    if full_scale_spatial_shape is None:
        max_coords = None
    elif isinstance(full_scale_spatial_shape, Tensor):
        # Match the dtype and device of the coordinate tensor tensor.C
        max_coords = full_scale_spatial_shape.to(tensor.C)
    else:
        # Any int sequence: list, tuple, or torch.Size (e.g. x.shape[1:-1])
        max_coords = torch.tensor(
            list(full_scale_spatial_shape), dtype=torch.int, device=tensor.device
        )
    out = __me_sparse(tensor, min_coords, max_coords)[0].coalesce()
    if squeeze:
        if out.values().shape[1] != 1:
            raise ValueError(
                "Got `squeeze`=True, but the MinkowskiEngine tensor has a feature "
                f"dim of {out.values().shape[1]}, not 1."
            )
        # Rebuild without the trailing dense feature dim of size 1
        out = torch.sparse_coo_tensor(
            out.indices(),
            out.values().squeeze(-1),
            out.shape[:-1],
            is_coalesced=out.is_coalesced(),
        )
    return out
spconv conversions
torch_sparse_to_spconv(tensor)
Converts a sparse torch.Tensor to an equivalent spconv SparseConvTensor
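Parameters:

| Name | Type | Description |
|---|---|---|
| `tensor` | `torch.Tensor` | Sparse tensor to be converted |

Returns:

| Type | Description |
|---|---|
| `spconv.SparseConvTensor` | Converted spconv tensor |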
Source code in pytorch_sparse_utils/conversion.py
@imports.requires_spconv
def torch_sparse_to_spconv(tensor: torch.Tensor):
    """Converts a sparse torch.Tensor to an equivalent spconv SparseConvTensor

    The leading dimension of ``tensor`` is used as the batch size. If the
    values carry a trailing dense feature dimension, the remaining sparse
    dimensions form the spatial shape. If the values are scalars, all
    dimensions after the batch are spatial and a feature channel of size 1
    is added. The tensor is coalesced first, so uncoalesced inputs are
    accepted and duplicate indices are summed. A tensor that is already a
    SparseConvTensor is returned unchanged.

    Args:
        tensor (torch.Tensor): Sparse tensor to be converted

    Returns:
        SparseConvTensor (spconv.SparseConvTensor): Converted spconv tensor
    """
    if isinstance(tensor, spconv.SparseConvTensor):
        return tensor
    assert tensor.is_sparse
    spatial_shape = list(tensor.shape[1:-1])
    batch_size = tensor.shape[0]
    # indices()/values() raise RuntimeError on an uncoalesced sparse tensor.
    coalesced = tensor.coalesce()
    indices_th = coalesced.indices()
    features_th = coalesced.values()
    if features_th.ndim == 1:
        # Tensor has scalar features, but spconv always expects 2D feature tensor
        features_th = features_th.unsqueeze(-1)
        spatial_shape = spatial_shape + [tensor.shape[-1]]
    # [sparse_dim, nnz] -> [nnz, sparse_dim] int32 indices
    indices_th = indices_th.permute(1, 0).contiguous().int()
    return spconv.SparseConvTensor(features_th, indices_th, spatial_shape, batch_size)
spconv_to_torch_sparse(tensor, squeeze=False)
Converts an spconv SparseConvTensor to a sparse torch.Tensor
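Parameters:

| Name | Type | Description |
|---|---|---|
| `tensor` | `spconv.SparseConvTensor` | spconv tensor to be converted |
| `squeeze` | `bool` | If the spconv tensor has a feature dimension of 1, setting this to True squeezes it out so that the resulting sparse Tensor has a `dense_dim()` of 0. Raises an error if the feature dimension is not 1. |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Converted sparse torch.Tensor |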
Source code in pytorch_sparse_utils/conversion.py
@imports.requires_spconv
def spconv_to_torch_sparse(tensor, squeeze=False) -> Tensor:
    """Convert an spconv SparseConvTensor into a coalesced sparse torch.Tensor.

    A tensor that is already a sparse torch.Tensor is passed through
    unchanged. Otherwise the result has size
    ``[batch_size, *spatial_shape, feature_dim]``, or
    ``[batch_size, *spatial_shape]`` when ``squeeze`` is set.

    Args:
        tensor (spconv.SparseConvTensor): spconv tensor to be converted
        squeeze (bool): When True, drop the size-1 feature dimension so that
            the resulting sparse Tensor has a dense_dim() of 0.

    Returns:
        tensor (Tensor): Converted sparse torch.Tensor

    Raises:
        ValueError: If ``squeeze`` is True but the feature dim is not 1.
    """
    if isinstance(tensor, Tensor) and tensor.is_sparse:
        return tensor
    assert isinstance(tensor, spconv.SparseConvTensor)

    feats = tensor.features
    n_channels = feats.shape[-1]
    if squeeze and n_channels != 1:
        raise ValueError(
            "Got `squeeze`=True, but the spconv tensor has a feature dim of "
            f"{tensor.features.shape[-1]}, not 1"
        )

    full_size = [tensor.batch_size] + tensor.spatial_shape
    if squeeze:
        vals = feats.squeeze(-1)
    else:
        # Keep the feature channel as a trailing dense dimension
        full_size = full_size + [n_channels]
        vals = feats

    # spconv stores indices as [nnz, ndim]; torch wants [ndim, nnz]
    coo = torch.sparse_coo_tensor(
        tensor.indices.transpose(0, 1),
        vals,
        full_size,
        device=feats.device,
        dtype=feats.dtype,
        requires_grad=feats.requires_grad,
        check_invariants=True,
    )
    return coo.coalesce()