Third-party sparse library integration

pytorch-sparse-utils features integrations with three major libraries for sparse arrays and tensors:

  • Pydata Sparse, a sparse array implementation with a NumPy-like interface and close NumPy integration.
  • MinkowskiEngine, an NVIDIA library for convolutions on sparse tensors.
  • spconv, another library for convolutions on sparse tensors.

All three libraries define their own sparse tensor object formats, distinct from the built-in PyTorch sparse tensors. The pytorch_sparse_utils.conversion module provides simple utilities for converting between these formats.

These conversion utilities enable, for example, a pipeline in which images are loaded as Pydata Sparse COO arrays, converted to PyTorch sparse tensors in a torch.utils.data.DataLoader, converted to MinkowskiEngine SparseTensors for processing through a MinkowskiEngine CNN backbone, and finally converted back to PyTorch sparse tensors for processing with a Transformer module.
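A rough sketch of such a pipeline is shown below, using a toy batch of sparse 2D images with scalar features; the CNN backbone and Transformer stages are hypothetical and elided, so only the conversions are shown:

import sparse
from pytorch_sparse_utils.conversion import (
    minkowski_to_torch_sparse,
    pydata_sparse_to_torch_sparse,
    torch_sparse_to_minkowski,
)

# A toy batch of sparse 2D "images" with scalar features: shape (batch, height, width)
pydata_batch = sparse.random((2, 128, 128), density=0.01)

# DataLoader/collate step: Pydata Sparse COO array -> PyTorch sparse tensor
torch_batch = pydata_sparse_to_torch_sparse(pydata_batch)

# Before the CNN backbone: PyTorch sparse tensor -> MinkowskiEngine SparseTensor
me_batch = torch_sparse_to_minkowski(torch_batch)
# me_batch = backbone(me_batch)  # hypothetical MinkowskiEngine CNN

# After the backbone: back to a PyTorch sparse tensor for a Transformer module
torch_features = minkowski_to_torch_sparse(me_batch, full_scale_spatial_shape=[128, 128])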


Pydata sparse conversions

torch_sparse_to_pydata_sparse(tensor)

Converts a sparse torch.Tensor to an equivalent Pydata sparse COO array

Parameters:
  • tensor (Tensor) –

    Sparse tensor to be converted

Returns:
  • array( COO ) –

    Pydata sparse COO array

Source code in pytorch_sparse_utils/conversion.py
def torch_sparse_to_pydata_sparse(tensor: Tensor) -> sparse.COO:
    """Converts a sparse torch.Tensor to an equivalent Pydata sparse COO array

    Args:
        tensor (torch.Tensor): Sparse tensor to be converted

    Returns:
        array (sparse.COO): Pydata sparse COO array
    """
    assert tensor.is_sparse
    tensor = tensor.detach().cpu().coalesce()
    assert tensor.is_coalesced()
    nonzero_values = tensor.values().nonzero(as_tuple=True)
    return sparse.COO(
        tensor.indices()[:, nonzero_values[0]].numpy(),
        tensor.values()[nonzero_values].numpy(),
        tensor.shape,
        has_duplicates=False,
    )
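
A minimal usage sketch (values are arbitrary):

import torch
from pytorch_sparse_utils.conversion import torch_sparse_to_pydata_sparse

dense = torch.tensor([[0.0, 1.0], [2.0, 0.0]])
torch_coo = dense.to_sparse()  # sparse COO torch.Tensor
pydata_coo = torch_sparse_to_pydata_sparse(torch_coo)
assert (pydata_coo.todense() == dense.numpy()).all()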

pydata_sparse_to_torch_sparse(sparse_array, device=None)

Converts a Pydata sparse COO array to an equivalent sparse torch.Tensor

Parameters:
  • sparse_array (COO) –

    Pydata sparse COO array to be converted

  • device (Optional[Union[str, device]], default: None ) –

    Device on which to create the sparse tensor. Defaults to None (default device).

Returns:
  • tensor( Tensor ) –

    Converted sparse tensor

Source code in pytorch_sparse_utils/conversion.py
def pydata_sparse_to_torch_sparse(
    sparse_array: sparse.COO, device: Optional[Union[str, torch.device]] = None
) -> Tensor:
    """Converts a Pydata sparse COO array to an equivalent sparse torch.Tensor

    Args:
        sparse_array (sparse.COO): Pydata sparse COO array to be converted
        device (Optional[Union[str, torch.device]]): Device on which to create the
            sparse tensor. Defaults to None (default device).

    Returns:
        tensor (torch.Tensor): Converted sparse tensor
    """
    return torch.sparse_coo_tensor(
        indices=sparse_array.coords,  # pyright: ignore[reportArgumentType]
        values=sparse_array.data,  # pyright: ignore[reportArgumentType]
        size=sparse_array.shape,
        device=device,
    ).coalesce()
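
And a sketch of the reverse direction:

import numpy as np
import sparse
from pytorch_sparse_utils.conversion import pydata_sparse_to_torch_sparse

pydata_coo = sparse.COO.from_numpy(np.eye(3))
torch_coo = pydata_sparse_to_torch_sparse(pydata_coo, device="cpu")
assert torch_coo.is_sparse and torch_coo.is_coalesced()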

MinkowskiEngine conversions

torch_sparse_to_minkowski(tensor)

Converts a sparse torch.Tensor to an equivalent MinkowskiEngine SparseTensor

Parameters:
  • tensor (Tensor) –

    Sparse tensor to be converted

Returns:
  • sparse_tensor( SparseTensor ) –

    Converted sparse tensor

Source code in pytorch_sparse_utils/conversion.py
@imports.requires_minkowskiengine
def torch_sparse_to_minkowski(tensor: Tensor):
    """Converts a sparse torch.Tensor to an equivalent MinkowskiEngine SparseTensor

    Args:
        tensor (torch.Tensor): Sparse tensor to be converted

    Returns:
        sparse_tensor (MinkowskiEngine.SparseTensor): Converted sparse tensor
    """
    assert isinstance(tensor, Tensor)
    assert tensor.is_sparse
    features = tensor.values()
    if features.ndim == 1:
        features = features.unsqueeze(-1)
    coordinates = tensor.indices().T.int().contiguous()
    return ME.SparseTensor(
        features, coordinates, requires_grad=tensor.requires_grad, device=tensor.device
    )
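
A minimal usage sketch, assuming MinkowskiEngine is installed; the first sparse dimension is treated as the batch index, matching MinkowskiEngine's coordinate convention:

import torch
from pytorch_sparse_utils.conversion import torch_sparse_to_minkowski

# Sparse tensor with layout (batch, x, y) and scalar values
indices = torch.tensor([[0, 0, 1], [3, 5, 2], [4, 1, 7]])  # (ndim, nnz)
values = torch.tensor([1.0, 2.0, 3.0])
torch_coo = torch.sparse_coo_tensor(indices, values, (2, 8, 8)).coalesce()

me_tensor = torch_sparse_to_minkowski(torch_coo)
# me_tensor.F: features of shape (nnz, 1); me_tensor.C: int coordinates of shape (nnz, 3)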

minkowski_to_torch_sparse(tensor, full_scale_spatial_shape=None, squeeze=False)

Converts a MinkowskiEngine SparseTensor to an equivalent sparse torch.Tensor

Parameters:
  • tensor (SparseTensor) –

    Sparse tensor to be converted

  • full_scale_spatial_shape (Optional[Union[list[int], Tensor]], default: None ) –

    The full extent of the spatial domain on which the sparse data reside. If given, will be used to define the size of the sparse tensor. If not given, the size will be inferred from the indices in the tensor. Default: None

  • squeeze (bool, default: False ) –

    If True and the feature dimension of the MinkowskiEngine SparseTensor is 1, the returned sparse torch.Tensor will have its values squeezed to 1D shape of [nnz] rather than [nnz, 1]. Raises an error if True and the feature dimension is not 1.

Returns:
  • tensor( Tensor ) –

    Converted sparse tensor

Source code in pytorch_sparse_utils/conversion.py
@imports.requires_minkowskiengine
def minkowski_to_torch_sparse(
    tensor: Union[Tensor, ME.SparseTensor],
    full_scale_spatial_shape: Optional[Union[Tensor, list[int]]] = None,
    squeeze: bool = False
) -> Tensor:
    """Converts a MinkowskiEngine SparseTensor to an equivalent sparse torch.Tensor

    Args:
        tensor (MinkowskiEngine.SparseTensor): Sparse tensor to be converted
        full_scale_spatial_shape (Optional[Union[list[int], Tensor]]): The full
            extent of the spatial domain on which the sparse data reside.
            If given, will be used to define the size of the sparse tensor. If not
            given, the size will be inferred from the indices in the tensor.
            Default: None
        squeeze (bool): If True and the feature dimension of the MinkowskiEngine
            SparseTensor is 1, the returned sparse torch.Tensor will have its values
            squeezed to 1D shape of [nnz] rather than [nnz, 1]. Raises an error if
            True and the feature dimension is not 1.

    Returns:
        tensor (torch.Tensor): Converted sparse tensor
    """
    if isinstance(tensor, Tensor):
        assert tensor.is_sparse
        return tensor
    assert isinstance(tensor, ME.SparseTensor)
    min_coords = torch.zeros([tensor.dimension], dtype=torch.int, device=tensor.device)
    if full_scale_spatial_shape is not None:
        if isinstance(full_scale_spatial_shape, list):
            max_coords = torch.tensor(
                full_scale_spatial_shape, dtype=torch.int, device=tensor.device
            )
        else:
            assert isinstance(full_scale_spatial_shape, Tensor)
            max_coords = full_scale_spatial_shape.to(tensor.C)
    else:
        max_coords = None
    out = __me_sparse(tensor, min_coords, max_coords)[0].coalesce()
    if squeeze:
        if out.values().shape[1] != 1:
            raise ValueError(
                "Got `squeeze`=True, but the MinkowskiEngine tensor has a feature "
                f"dim of {out.values().shape[1]}, not 1."
            )
        out = torch.sparse_coo_tensor(
            out.indices(),
            out.values().squeeze(-1),
            out.shape[:-1],
            is_coalesced=out.is_coalesced()
        )
    return out
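
Continuing the sketch above, a hedged example of converting back while pinning the spatial extent to 8x8 (so the output size does not depend on which coordinates happen to be occupied) and squeezing out the size-1 feature dimension:

from pytorch_sparse_utils.conversion import minkowski_to_torch_sparse

torch_coo_again = minkowski_to_torch_sparse(
    me_tensor, full_scale_spatial_shape=[8, 8], squeeze=True
)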

spconv conversions

torch_sparse_to_spconv(tensor)

Converts a sparse torch.Tensor to an equivalent spconv SparseConvTensor

Parameters:
  • tensor (Tensor) –

    Sparse tensor to be converted

Returns:
  • SparseConvTensor( SparseConvTensor ) –

    Converted spconv tensor

Source code in pytorch_sparse_utils/conversion.py
@imports.requires_spconv
def torch_sparse_to_spconv(tensor: torch.Tensor):
    """Converts a sparse torch.Tensor to an equivalent spconv SparseConvTensor

    Args:
        tensor (torch.Tensor): Sparse tensor to be converted

    Returns:
        SparseConvTensor (spconv.SparseConvTensor): Converted spconv tensor
    """
    if isinstance(tensor, spconv.SparseConvTensor):
        return tensor
    assert tensor.is_sparse
    spatial_shape = list(tensor.shape[1:-1])
    batch_size = tensor.shape[0]
    indices_th = tensor.indices()
    features_th = tensor.values()
    if features_th.ndim == 1:
        # Tensor has scalar features, but spconv always expects 2D feature tensor
        features_th = features_th.unsqueeze(-1)
        spatial_shape = spatial_shape + [tensor.shape[-1]]
    indices_th = indices_th.permute(1, 0).contiguous().int()
    return spconv.SparseConvTensor(features_th, indices_th, spatial_shape, batch_size)
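
A minimal usage sketch, assuming spconv is installed (shapes and values are arbitrary):

import torch
from pytorch_sparse_utils.conversion import torch_sparse_to_spconv

# Sparse tensor with layout (batch, x, y, channels): three sparse dims plus
# one dense feature dim of size 4
indices = torch.tensor([[0, 0, 1], [2, 5, 3], [4, 1, 6]])  # (batch, x, y) for 3 points
values = torch.randn(3, 4)                                  # 4 features per point
torch_coo = torch.sparse_coo_tensor(indices, values, (2, 8, 8, 4)).coalesce()

spconv_tensor = torch_sparse_to_spconv(torch_coo)
# spconv_tensor.features: (nnz, 4); spconv_tensor.indices: (nnz, 3) int32
# spconv_tensor.spatial_shape == [8, 8]; spconv_tensor.batch_size == 2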

spconv_to_torch_sparse(tensor, squeeze=False)

Converts an spconv SparseConvTensor to a sparse torch.Tensor

Parameters:
  • tensor (SparseConvTensor) –

    spconv tensor to be converted

  • squeeze (bool, default: False ) –

    If the spconv tensor has a feature dimension of 1, setting this to true squeezes it out so that the resulting sparse Tensor has a dense_dim() of 0. Raises an error if the spconv feature dim is not 1.

Returns:
  • tensor( Tensor ) –

    Converted sparse torch.Tensor

Source code in pytorch_sparse_utils/conversion.py
@imports.requires_spconv
def spconv_to_torch_sparse(tensor, squeeze=False) -> Tensor:
    """Converts an spconv SparseConvTensor to a sparse torch.Tensor

    Args:
        tensor (spconv.SparseConvTensor): spconv tensor to be converted
        squeeze (bool): If the spconv tensor has a feature dimension of 1,
            setting this to true squeezes it out so that the resulting
            sparse Tensor has a dense_dim() of 0. Raises an error if the spconv
            feature dim is not 1.

    Returns:
        tensor (Tensor): Converted sparse torch.Tensor
    """
    if isinstance(tensor, Tensor) and tensor.is_sparse:
        return tensor
    assert isinstance(tensor, spconv.SparseConvTensor)
    if squeeze:
        if tensor.features.shape[-1] != 1:
            raise ValueError(
                "Got `squeeze`=True, but the spconv tensor has a feature dim of "
                f"{tensor.features.shape[-1]}, not 1"
            )
        size = [tensor.batch_size] + tensor.spatial_shape
        values = tensor.features.squeeze(-1)
    else:
        size = [tensor.batch_size] + tensor.spatial_shape + [tensor.features.shape[-1]]
        values = tensor.features
    indices = tensor.indices.transpose(0, 1)
    out = torch.sparse_coo_tensor(
        indices,
        values,
        size,
        device=tensor.features.device,
        dtype=tensor.features.dtype,
        requires_grad=tensor.features.requires_grad,
        check_invariants=True,
    )
    out = out.coalesce()
    return out
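
Converting the spconv tensor from the sketch above back to a PyTorch sparse tensor:

from pytorch_sparse_utils.conversion import spconv_to_torch_sparse

torch_coo_again = spconv_to_torch_sparse(spconv_tensor)
# Shape (2, 8, 8, 4): sparse dims (batch, x, y) plus one dense feature dim
assert torch_coo_again.sparse_dim() == 3 and torch_coo_again.dense_dim() == 1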