Utility functions
Assorted utility functions for sparse and irregularly-structured tensors.
Vector lexicographical sorting
lexsort_nd(tensor, vector_dim, sort_dim, descending=False, stable=False, force_robust=False)
Sorts a tensor of vectors in lexicographic order.
Given a tensor of vectors, performs a sort that orders the vectors in lexicographic order. The vectors are defined along the vector_dim dimension and sorted along the sort_dim dimension.
If force_robust is False, then a fast lexicographic sort based on projecting the vectors to an order-preserving 1D basis is used if possible, falling back to a "robust" (true) multi-pass lexicographic sort if the input vectors cannot be losslessly compressed to 1D. If force_robust is True, the robust sort is always used.
Both integer and floating-point tensors are supported.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tensor` | Tensor | Tensor to be sorted. | required |
| `vector_dim` | int | Index along which vectors are defined. | required |
| `sort_dim` | int | Index along which to sort. | required |
| `descending` | bool | If True, vectors are sorted in descending order. | `False` |
| `stable` | bool | If True, a stable sort is always used (order of equivalent values is kept). If False, unstable sorts are used when possible. | `False` |
| `force_robust` | bool | If True, always use the "true" iterative lexsort. This requires `tensor.shape[vector_dim]` sorts instead of one, but is more reproducible. | `False` |

Returns:

| Type | Description |
|---|---|
| tuple[Tensor, Tensor] | The sorted tensor and the sort indices. |
Notes
- The relationship between the sorted tensor and the sort indices is:
  sort_indices_exp = sort_indices.unsqueeze(vector_dim).expand_as(tensor)
  sorted_tensor = tensor.gather(sort_dim, sort_indices_exp)
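For illustration, here is a minimal usage sketch. The import path is an assumption based on the source location shown below, and the data is arbitrary; the assertion only checks the gather relationship documented in the Notes.

import torch
from pytorch_sparse_utils.utils import lexsort_nd  # import path assumed

# Four 3-component integer vectors: components along dim 1, sorting along dim 0
coords = torch.tensor(
    [[2, 1, 0],
     [1, 3, 2],
     [1, 0, 5],
     [2, 0, 4]]
)
sorted_coords, sort_indices = lexsort_nd(coords, vector_dim=-1, sort_dim=0)

# Reproduce the sorted tensor from the sort indices, as described in the Notes
sort_indices_exp = sort_indices.unsqueeze(-1).expand_as(coords)
assert torch.equal(coords.gather(0, sort_indices_exp), sorted_coords)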
Source code in pytorch_sparse_utils/utils/lexsort_nd.py
@torch.jit.script
def lexsort_nd(
tensor: Tensor,
vector_dim: int,
sort_dim: int,
descending: bool = False,
stable: bool = False,
force_robust: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Sorts a tensor of vectors in lexicographic order.
Given a tensor of vectors, performs a sort that orders the vectors
in lexicographic order. The vectors are defined along the `vector_dim` dimension,
and sorted along the `sort_dim` dimension.
If `force_robust` is False, then a fast lexicographic sort based on projecting the
vectors to an order-preserving 1D basis is used if possible, falling back to a
"robust" (true) multi-pass lexicographic sort if the input vectors cannot be
losslessly compressed to 1D. If `force_robust` is True, the robust sort is always
used.
Both integer and floating-point tensors are supported.
Args:
tensor (Tensor): Tensor to be sorted.
vector_dim (int): Index along which vectors are defined.
sort_dim (int): Index along which to sort.
descending (bool): If True, vectors are sorted in descending order. Default: False.
stable (bool): If True, stable sort is always used (order of equivalent values is kept).
If False, unstable sorts are used when possible.
force_robust (bool): If True, always use the "true" iterative lexsort. This requires
tensor.shape[vector_dim] sorts instead of 1 sort, but is more reproducible.
Returns:
tuple[Tensor, Tensor]:
- Tensor: Sorted tensor.
- Tensor: Sort indices.
Notes:
- The relationship between the sorted tensor and the sort indices is:
sort_indices_exp = sort_indices.unsqueeze(vector_dim).expand_as(tensor)
sorted_tensor = tensor.gather(sort_dim, sort_indices_exp).
"""
# Normalize dims
ndim = tensor.ndim
vector_dim = vector_dim if vector_dim >= 0 else vector_dim + ndim
sort_dim = sort_dim if sort_dim >= 0 else sort_dim + ndim
# Input checks
if vector_dim < 0 or vector_dim >= ndim:
raise ValueError(
f"Normalized key_dim {vector_dim} is out of bounds for tensor with {ndim} "
"dimensions."
)
if sort_dim < 0 or sort_dim >= ndim:
raise ValueError(
f"Normalized sort_dim {sort_dim} is out of bounds for tensor with {ndim} "
"dimensions."
)
if sort_dim == vector_dim:
raise ValueError(
f"Expected vector_dim and sort_dim to be different, but got both "
f"= {sort_dim}"
)
if tensor.isnan().any():
raise ValueError("Tensor has nan values.")
if tensor.isinf().any():
raise ValueError("Tensor has infinite values.")
# Get vector length
vector_len = tensor.shape[vector_dim]
# Handle edge cases
if tensor.numel() == 0:
indices_shape = list(tensor.shape)
indices_shape.pop(vector_dim)
return tensor, torch.zeros(
indices_shape, device=tensor.device, dtype=torch.long
)
if tensor.size(sort_dim) == 1:
indices_shape = list(tensor.shape)
indices_shape.pop(vector_dim)
return tensor, torch.zeros(
indices_shape, device=tensor.device, dtype=torch.long
)
if vector_len == 1: # Just do regular sort
tensor, sort_indices = torch.sort(
tensor,
dim=sort_dim,
descending=descending,
stable=stable,
)
sort_indices = sort_indices.squeeze(vector_dim)
return tensor, sort_indices
# Move vector_dim to last position for projection reduction
# and sort_dim to first position for faster sorting
tensor_permuted, perm = _permute_dims(tensor, vector_dim, sort_dim)
tensor_permuted = tensor_permuted.contiguous()
# List of integer types
_INT_TYPES = (
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
)
# Pick appropriate sorting subroutine
if force_robust:
sorted_tensor_permuted, indices = _lexsort_nd_robust(
tensor_permuted, descending=descending
)
elif torch.is_floating_point(tensor_permuted):
sorted_tensor_permuted, indices = _lexsort_nd_float(
tensor_permuted, descending, stable
)
elif tensor_permuted.dtype in _INT_TYPES:
indices = _lexsort_nd_int(tensor_permuted, descending, stable).sort_indices
sorted_tensor_permuted = None
else:
raise ValueError(f"Unsupported tensor dtype {tensor.dtype}")
# Gather from the original tensor using the sort indices
indices_unsq = indices.unsqueeze(-1) # add singleton dim at permuted vector dim
if sorted_tensor_permuted is None: # get sorted tensor if not returned already
sorted_tensor_permuted = torch.gather(
tensor_permuted, dim=0, index=indices_unsq.expand_as(tensor_permuted)
)
# Permute tensor and indices back to original dimension order
inverse_perm = [0] * tensor.ndim
for i, p in enumerate(perm):
inverse_perm[p] = i
sorted_tensor = sorted_tensor_permuted.permute(inverse_perm)
sort_indices = indices_unsq.permute(inverse_perm).squeeze(vector_dim)
return sorted_tensor, sort_indices
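To make the two strategies described above concrete, the following is a self-contained conceptual sketch, not the library's private helpers such as _lexsort_nd_float or _lexsort_nd_int, showing how a 1D order-preserving projection and a multi-pass stable-sort fallback can each lexicographically order nonnegative integer vectors, with component 0 treated as most significant:

import torch

def lexsort_by_projection(keys: torch.Tensor) -> torch.Tensor:
    # Fast-path idea: pack each row of nonnegative integer components into a single
    # scalar whose ordering matches lexicographic order (valid only while the
    # packed values fit losslessly in int64).
    ranges = keys.amax(dim=0) + 1                  # per-component radix
    weights = torch.ones_like(ranges)
    for i in range(ranges.numel() - 2, -1, -1):    # mixed-radix place values
        weights[i] = weights[i + 1] * ranges[i + 1]
    packed = (keys * weights).sum(dim=1)           # order-preserving 1D projection
    return torch.argsort(packed)

def lexsort_multipass(keys: torch.Tensor) -> torch.Tensor:
    # Robust-path idea: one stable sort per component, least significant first.
    order = torch.arange(keys.size(0))
    for col in range(keys.size(1) - 1, -1, -1):
        order = order[torch.argsort(keys[order, col], stable=True)]
    return order

Either function returns a permutation that sorts the rows lexicographically; lexsort_nd additionally handles floats, negative values, descending order, arbitrary vector_dim/sort_dim placement, and the losslessness check that decides which path to take.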
Concatenated-batch topk
BatchTopK
Bases: NamedTuple
Source code in pytorch_sparse_utils/utils/batch_topk.py
class BatchTopK(NamedTuple):
indices: Tensor
offsets: Tensor
values: Optional[Tensor] = None
batch_topk(tensor, batch_offsets, k, dim=0, largest=True, sorted=True, return_values=False)
Performs top-k operation on a batch-concatenated tensor with variable sequence lengths.
This function handles both uniform-length sequences (where a more efficient batch operation is used) and variable-length sequences (where per-batch processing occurs). The function returns indices adjusted to the original tensor's indexing space along with offsets to identify each batch's results.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tensor` | Tensor | A batch-concatenated tensor of shape (total_length, d1, d2, ...), where total_length is the sum of all sequence lengths. | required |
| `batch_offsets` | Tensor | A 1D tensor of indices indicating where each sequence begins in the batch-concatenated tensor. Should be of shape (batch_size + 1,), with the last element being the total length. | required |
| `k` | Union[Tensor, list[int], int] | Number of top elements to select. Can be an integer for the same k across all batches, or a tensor or list for a different k per batch. Clamped to each sequence's length if k > sequence_length. | required |
| `dim` | int | Dimension along which to perform the top-k operation for each concatenated subsequence. | `0` |
| `largest` | bool | If True, returns the indices of the largest elements; if False, those of the smallest elements. | `True` |
| `sorted` | bool | If True, always returns the elements in sorted order. For technical reasons, the returned elements may be sorted in some cases even when False. | `True` |
| `return_values` | bool | If True, the output namedtuple also includes the top-k values in addition to the indices and offsets. | `False` |

Returns:

| Type | Description |
|---|---|
| BatchTopK | Namedtuple of `indices` (1-D long tensor of top-k indices in the original concatenated tensor space), `offsets` (1-D long tensor of shape (batch_size + 1,) delimiting each batch's results, with `offsets[-1] == len(indices)`), and `values` (optional 1-D tensor of the top-k values; None unless `return_values=True`). |
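For orientation, a batch-concatenated input and its batch_offsets can be built from a list of variable-length sequences as follows. This is an illustrative sketch; batch_topk only needs the concatenated tensor and the offsets, however they are produced (the source below uses a seq_lengths_to_batch_offsets helper for the same conversion).

import torch

# Variable-length sequences that share their trailing dimensions
seqs = [torch.randn(4, 8), torch.randn(2, 8), torch.randn(3, 8)]

tensor = torch.cat(seqs, dim=0)                     # shape (9, 8): total_length = 4 + 2 + 3
seq_lens = torch.tensor([s.size(0) for s in seqs])
batch_offsets = torch.cat(
    [torch.zeros(1, dtype=torch.long), seq_lens.cumsum(0)]
)                                                   # tensor([0, 4, 6, 9])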
Usage Notes
1. Slice per sequence
   The results for the b-th input sequence are obtained with
   idx_b = out.indices[out.offsets[b] : out.offsets[b + 1]]
   if out.values is not None:
       val_b = out.values[out.offsets[b] : out.offsets[b + 1]]
   where out = batch_topk(...).
2. Expected slice length
   Let
   seq_len_b = batch_offsets[b + 1] - batch_offsets[b]
   k_b = min(k, seq_len_b)  # or k[b] if k is per-sample
   Then length(idx_b) is
   k_b * prod(tensor.size(i) for i not in {0, dim})              if dim == 0
   k_b * seq_len_b * prod(tensor.size(i) for i not in {0, dim})  otherwise
   This is exactly the number of elements returned by a regular torch.topk call on that subsequence.
3. Reshaping to the usual top-k shape
   After selecting subseq_b = tensor[batch_offsets[b] : batch_offsets[b + 1]], you can turn the flat index slice back into the layout produced by torch.topk:
   out_shape = list(subseq_b.shape)
   out_shape[dim] = k_b
   idx_b = idx_b.view(*out_shape)
   The same out_shape works for val_b when it is present.
4. Recovering the values when they were not returned
   If return_values=False, the values can still be gathered later.
   - When dim != 0 the indices are local to every subsequence, so you can gather directly:
     val_b = torch.take_along_dim(subseq_b, idx_b, dim)
   - When dim == 0 the indices are expressed in the global (concatenated-tensor) space. Either use them on the original concatenated tensor,
     val_b = torch.take_along_dim(tensor, idx_b, dim)
     or convert them back to local coordinates first:
     idx_b_local = idx_b - batch_offsets[b]
     val_b = torch.take_along_dim(subseq_b, idx_b_local, dim)
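A short end-to-end sketch tying these notes together for dim == 0; the import path is assumed from the source location below and the data is illustrative.

import torch
from pytorch_sparse_utils.utils import batch_topk  # import path assumed

# Three concatenated sequences of lengths 4, 2, and 3, each with 5 features
tensor = torch.randn(9, 5)
batch_offsets = torch.tensor([0, 4, 6, 9])
k = 3

out = batch_topk(tensor, batch_offsets, k=k, dim=0, return_values=True)

for b in range(batch_offsets.numel() - 1):
    subseq_b = tensor[int(batch_offsets[b]) : int(batch_offsets[b + 1])]
    k_b = min(k, subseq_b.size(0))                    # k is clamped per sequence

    idx_b = out.indices[out.offsets[b] : out.offsets[b + 1]]
    val_b = out.values[out.offsets[b] : out.offsets[b + 1]]

    # Reshape the flat slices back to the usual torch.topk layout (Note 3)
    idx_b = idx_b.view(k_b, subseq_b.size(1))
    val_b = val_b.view(k_b, subseq_b.size(1))

    # Values agree with a per-subsequence torch.topk call
    ref_vals, _ = subseq_b.topk(k_b, dim=0)
    assert torch.equal(val_b, ref_vals)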
Source code in pytorch_sparse_utils/utils/batch_topk.py
@torch.jit.script
def batch_topk(
tensor: Tensor,
batch_offsets: Tensor,
k: Union[Tensor, list[int], int],
dim: int = 0,
largest: bool = True,
sorted: bool = True,
return_values: bool = False,
) -> BatchTopK:
"""
Performs top-k operation on a batch-concatenated tensor with variable sequence lengths.
This function handles both uniform-length sequences (where a more efficient batch
operation is used) and variable-length sequences (where per-batch processing occurs).
The function returns indices adjusted to the original tensor's indexing space
along with offsets to identify each batch's results.
Args:
tensor (Tensor): A batch-concatenated tensor of shape (total_length, d1, d2, ...)
where total_length is the sum of all sequence lengths.
batch_offsets (Tensor): A 1D tensor of indices indicating where each sequence
begins in the batch-concatenated tensor. Should be of shape (batch_size + 1,)
with the last element being the total length.
        k (Union[Tensor, list[int], int]): Number of top elements to select. Can be an integer
for the same k across all batches, or a tensor or list for different k per
batch. Will be clamped to each sequence's length if k > sequence_length.
dim (int, optional): Dimension along which to perform the top-k operation for
each concatenated subsequence. Default: 0 (sequence dimension).
largest (bool, optional): If True, returns the indices of the largest elements.
If False, returns those of the smallest elements. Default: True.
sorted (bool, optional): If True, always returns the elements in sorted order.
For technical reasons, the returned elements may be sorted in some cases
even when False. Default: True.
return_values (bool, optional): If True, the output namedtuple will include the
topk values in addition to the indices and offsets. Default: False.
Returns:
BatchTopK: A namedtuple containing:
- indices (Tensor): 1-D long tensor of the top-k indices in the original
concatenated tensor space.
- offsets (Tensor): 1-D long tensor of offsets with shape (batch_size + 1)
indicating where each batch's results begin in the topk_indices tensor.
Like the input batch_offsets,
BatchTopK.offsets[-1] == len(BatchTopK.indices).
- values (Optional[Tensor]): 1-D tensor of the same length as `indices` with
                the actual topk values. Returned only if return_values is True, otherwise
is None.
Usage Notes:
1. Slice per sequence
The results for the b-th input sequence are obtained with
idx_b = out.indices[out.offsets[b] : out.offsets[b + 1]]
if out.values is not None:
val_b = out.values[out.offsets[b] : out.offsets[b + 1]]
where `out = batch_topk(...)`.
2. Expected slice length
Let
seq_len_b = batch_offsets[b + 1] - batch_offsets[b]
k_b = min(k, seq_len_b) # or k[b] if k is per-sample
length(idx_b) is
k_b * prod(tensor.size(i) for i not in {0, dim}) if dim == 0
k_b * seq_len_b * prod(tensor.size(i) for i not in {0, dim}) otherwise
This is exactly the number of elements returned by a
regular torch.topk call on that subsequence.
3. Reshaping to the usual top-k shape
After selecting `subseq_b = tensor[batch_offsets[b] : batch_offsets[b + 1]]`
you can turn the flat index slice back into the layout produced by
torch.topk:
out_shape = list(subseq_b.shape)
out_shape[dim] = k_b
idx_b = idx_b.view(*out_shape)
The same `out_shape` works for `val_b` when it is present.
4. Recovering the values when they were not returned
If `return_values=False`, the values can still be gathered later.
- When ``dim != 0`` the indices are **local** to every subsequence,
therefore you can gather directly:
val_b = torch.take_along_dim(subseq_b, idx_b, dim)
- When ``dim == 0`` the indices are expressed in the **global**
(concatenated-tensor) space. Either use them on the original
concatenated tensor,
val_b = torch.take_along_dim(tensor, idx_b, dim)
or convert them back to local coordinates first:
idx_b_local = idx_b - batch_offsets[b]
val_b = torch.take_along_dim(subseq_b, idx_b_local, dim)
"""
seq_lens = batch_offsets_to_seq_lengths(batch_offsets)
assert isinstance(seq_lens, Tensor)
batch_size = seq_lens.numel()
# Normalize k
k = _normalize_k(k, batch_size, tensor.device)
assert isinstance(k, Tensor)
if torch.any(k < 0):
raise ValueError(f"Expected nonnegative value for `k`, got {k}")
# Normalize dim if negative
dim = dim if dim >= 0 else dim + tensor.ndim
if dim < 0 or dim >= tensor.dim():
raise ValueError(
"Normalized dimension "
f"{dim} is out of bounds for tensor with {tensor.ndim} dimensions"
)
# Early return for empty tensor
if batch_size == 0 or tensor.numel() == 0:
topk_indices = torch.empty(0, device=tensor.device, dtype=torch.long)
topk_offsets = torch.zeros(1, device=tensor.device, dtype=torch.long)
topk_values = tensor[0:0].flatten() if return_values else None # keep gradients
return BatchTopK(topk_indices, topk_offsets, topk_values)
# Find product of dims besides seq and topk dims
prod_extra_dims = 1
for i, s in enumerate(tensor.shape):
if i not in (0, dim):
prod_extra_dims *= s
same_len = torch.all(seq_lens == seq_lens[0])
same_k = torch.all(k == k[0])
if same_len: # sequences same length, batch for efficiency
# Clamp k to seq length
seq_len_int = int(seq_lens[0].item())
if dim == 0:
k_max = torch.min(k.amax(), seq_lens[0])
else:
k_max = k.amax().clamp_max(tensor.shape[dim])
k = k.clamp_max(k_max)
k_max_int = int(k_max.item())
# Compute per-batch result size
if dim == 0:
out_sizes = k * prod_extra_dims
else:
out_sizes = k * seq_len_int * prod_extra_dims
topk_offsets = seq_lengths_to_batch_offsets(out_sizes)
assert isinstance(topk_offsets, Tensor)
if k_max_int == 0:
topk_indices = tensor.new_empty(0, dtype=torch.long)
topk_values = tensor[0:0].flatten() if return_values else None
return BatchTopK(topk_indices, topk_offsets, topk_values)
# reshape to [bsz, seq_len, ...]
batch_shape = (batch_size, seq_len_int) + tensor.shape[1:]
topk_dim = dim + 1 # account for new leading batch dim
values_all, indices_all = tensor.reshape(batch_shape).topk(
k_max_int, topk_dim, largest=largest, sorted=True
) # Need to be sorted to be able to select first k for each subseq
# If topk is along sequence length, need to add offsets to indices
# to globalize them
if dim == 0:
unsqueeze_shape = (batch_size,) + (1,) * (indices_all.ndim - 1)
indices_all = indices_all + batch_offsets[:-1].view(unsqueeze_shape)
if same_k:
topk_indices = indices_all.flatten()
topk_values = values_all.flatten() if return_values else None
return BatchTopK(topk_indices, topk_offsets, topk_values)
# not all same k: slice into the topk output for each batch
total_len = int(topk_offsets[-1])
topk_indices = tensor.new_empty(total_len, dtype=torch.long)
topk_values = tensor.new_empty(total_len) if return_values else None
for b in range(batch_size):
k_b = int(k[b])
if k_b == 0:
continue
# slice into the already-computed result
start, end = int(topk_offsets[b]), int(topk_offsets[b + 1])
topk_indices[start:end] = indices_all[b].narrow(dim, 0, k_b).flatten()
if return_values:
assert topk_values is not None
topk_values[start:end] = values_all[b].narrow(dim, 0, k_b).flatten()
return BatchTopK(topk_indices, topk_offsets, topk_values)
##########
# Slow path
# -- Sequences different length, run topk for each --
if dim == 0:
batch_seq_ks = seq_lens.clamp_max(k)
batch_out_sizes = batch_seq_ks * prod_extra_dims
else:
dim_size = tensor.size(dim)
batch_seq_ks = k.clamp_max(dim_size)
batch_out_sizes = batch_seq_ks * seq_lens * prod_extra_dims
topk_offsets = seq_lengths_to_batch_offsets(batch_out_sizes)
assert isinstance(topk_offsets, Tensor)
# Allocate result tensor(s)
topk_indices = torch.empty(
int(topk_offsets[-1].item()), dtype=torch.long, device=tensor.device
)
topk_values: Optional[Tensor] = None
if return_values:
topk_values = tensor.new_empty(int(topk_offsets[-1]))
# Allocate a "scratch" buffer to hold the topk values outputs
max_seq_len = int(seq_lens.max().item())
max_k = int(batch_seq_ks.max().item()) if dim == 0 else int(k.max().item())
scratch_shape = list(tensor.shape)
scratch_shape[dim] = max_k
if dim != 0:
scratch_shape[0] = max_seq_len
scratch_values = tensor.new_empty(scratch_shape)
# per-batch topk
for b, k_b in enumerate(batch_seq_ks):
k_b = int(k_b)
batch_start, batch_end = int(batch_offsets[b]), int(batch_offsets[b + 1])
slice_start, slice_end = int(topk_offsets[b]), int(topk_offsets[b + 1])
# slice of the big concatted sequence tensor that represents this subsequence
subseq_b = tensor[batch_start:batch_end]
# Set up the view into the output topk tensor
subseq_shape = list(subseq_b.shape)
subseq_shape[dim] = k_b
topk_inds_subseq_view = topk_indices[slice_start:slice_end].view(subseq_shape)
# Set up view into reusable scratch holder for topk values output
if dim == 0:
val_buffer = scratch_values[:k_b]
else:
val_buffer = scratch_values.narrow(0, 0, subseq_b.size(0))
val_buffer = val_buffer.narrow(dim, 0, k_b)
_topk_out(
subseq_b.detach(),
k_b,
dim=dim,
largest=largest,
sorted=sorted,
out_values=val_buffer,
out_indices=topk_inds_subseq_view,
)
if return_values:
# clone topk inds to save unmodified tensor for take_along_dim backward
values = torch.take_along_dim(subseq_b, topk_inds_subseq_view.clone(), dim)
assert topk_values is not None
topk_values[slice_start:slice_end] = values.reshape(-1)
if dim == 0:
topk_inds_subseq_view.add_(batch_start)
return BatchTopK(topk_indices, topk_offsets, topk_values)
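A complementary sketch for dim != 0 with a per-sample k, where the returned indices are local to each subsequence (Usage Note 4); the import path is assumed and the data is illustrative.

import torch
from pytorch_sparse_utils.utils import batch_topk  # import path assumed

# Two concatenated sequences of lengths 3 and 5, with a feature dimension of size 6
tensor = torch.randn(8, 6)
batch_offsets = torch.tensor([0, 3, 8])

# Per-sample k, top-k taken along dim=1 (the feature dimension)
out = batch_topk(tensor, batch_offsets, k=[2, 4], dim=1)

for b, k_b in enumerate([2, 4]):
    subseq_b = tensor[int(batch_offsets[b]) : int(batch_offsets[b + 1])]

    idx_b = out.indices[out.offsets[b] : out.offsets[b + 1]]
    idx_b = idx_b.view(subseq_b.size(0), k_b)          # usual torch.topk layout

    # dim != 0: indices are local, so values can be gathered directly
    val_b = torch.take_along_dim(subseq_b, idx_b, dim=1)
    ref_vals, _ = subseq_b.topk(k_b, dim=1)
    assert torch.equal(val_b, ref_vals)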
unpack_batch_topk(result, batch_offsets, original_shape, dim=0)
Re-shape and localize the flattened indices/values inside a BatchTopK object.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `result` | BatchTopK | The object returned by `batch_topk`. | required |
| `batch_offsets` | Tensor | Same offsets that were passed to `batch_topk`. | required |
| `original_shape` | list[int] | `tensor.shape` of the concatenated tensor that was given to `batch_topk`. | required |
| `dim` | int | Dimension along which top-k was computed (same value that was given to `batch_topk`). | `0` |

Returns:

| Type | Description |
|---|---|
| tuple[list[Tensor], Optional[list[Tensor]]] | `indices_per_batch`: one tensor per input sequence, each shaped like the output of torch.topk(subseq, ...) for that subsequence; `values_per_batch`: the matching values, or None if `batch_topk` was called with `return_values=False`. |
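A brief usage sketch (import path assumed, data illustrative) showing how the helper restores the per-sequence torch.topk layout from a batch_topk result:

import torch
from pytorch_sparse_utils.utils import batch_topk, unpack_batch_topk  # import path assumed

tensor = torch.randn(7, 4)                 # two concatenated sequences of lengths 3 and 4
batch_offsets = torch.tensor([0, 3, 7])

out = batch_topk(tensor, batch_offsets, k=2, dim=0, return_values=True)
indices_per_batch, values_per_batch = unpack_batch_topk(
    out, batch_offsets, original_shape=list(tensor.shape), dim=0
)
# Each list entry has the shape torch.topk would give for that subsequence,
# e.g. indices_per_batch[0].shape == values_per_batch[0].shape == (2, 4)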
Source code in pytorch_sparse_utils/utils/batch_topk.py
@torch.jit.script
def unpack_batch_topk(
result: BatchTopK,
batch_offsets: Tensor,
original_shape: list[int],
dim: int = 0,
) -> tuple[list[Tensor], Optional[list[Tensor]]]:
"""
Re-shape and localize the flattened `indices`/`values` inside `BatchTopK` object.
Args:
result (BatchTopK): The object returned by :pyfunc:`batch_topk`.
batch_offsets (Tensor): Same offsets that were passed to `batch_topk`.
        original_shape (list[int]): ``tensor.shape`` of the concatenated tensor that
was given to `batch_topk`.
dim (int): Dimension along which top-k was computed (same value that was given
to `batch_topk`).
Returns:
indices_per_batch (list[Tensor]): List containing one tensor per input
sequence. Each tensor is the same shape as returned by a call to
torch.topk(subseq, ...) for that subsequence.
values_per_batch (Optional[list[Tensor]]): List containing one tensor
per input sequence. Like indices_per_batch, each tensor will be the
same shape as returned by a standalone call torch.topk(subseq, ...).
If `batch_topk` was originally called with `return_values=False`,
then `values_per_batch` will be None.
"""
# Normalize possibly negative dim
dim = dim if dim >= 0 else dim + len(original_shape)
indices_per_batch: list[Tensor] = []
values_per_batch: Optional[list[Tensor]] = [] if result.values is not None else None
# Compute the product of the other dims to determine subsequence topk size
prod_other_dims = 1
for i, s in enumerate(original_shape):
if i not in (0, dim):
prod_other_dims *= s
for b in range(batch_offsets.numel() - 1):
# Sub-range of the concatenated tensor
start, end = int(batch_offsets[b]), int(batch_offsets[b + 1])
seq_len_b = end - start
# Slice into the flattened top-k output
slice_start, slice_end = int(result.offsets[b]), int(result.offsets[b + 1])
idx_flat_global = result.indices[slice_start:slice_end]
# Convert to local coordinates when top-k was along the sequence dim
idx_flat_local = idx_flat_global - start if dim == 0 else idx_flat_global
# Derive k_b from the number of elements
if idx_flat_local.numel() == 0:
k_b = 0
elif dim == 0:
k_b = idx_flat_local.numel() // prod_other_dims
else:
k_b = idx_flat_local.numel() // (seq_len_b * prod_other_dims)
# Build the full output shape for this subsequence
out_shape = list(original_shape)
if dim == 0:
out_shape[0] = k_b
else:
out_shape[0] = seq_len_b
out_shape[dim] = k_b
# Reshape and store
indices_per_batch.append(idx_flat_local.view(out_shape))
if result.values is not None:
vals = result.values
assert vals is not None
vals_b = vals[slice_start:slice_end]
assert values_per_batch is not None
values_per_batch.append(vals_b.view(out_shape)) # type: ignore
return indices_per_batch, values_per_batch
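As a final sanity check, the unpacked output can be compared against an independent torch.topk call on each subsequence (a hedged sketch with the same assumed import path; exact index equality assumes there are no tied values):

import torch
from pytorch_sparse_utils.utils import batch_topk, unpack_batch_topk  # import path assumed

tensor = torch.randn(10, 3)
batch_offsets = torch.tensor([0, 4, 10])
k = 3

out = batch_topk(tensor, batch_offsets, k=k, dim=0, return_values=True)
indices_per_batch, values_per_batch = unpack_batch_topk(
    out, batch_offsets, original_shape=list(tensor.shape), dim=0
)

for b in range(batch_offsets.numel() - 1):
    subseq_b = tensor[int(batch_offsets[b]) : int(batch_offsets[b + 1])]
    ref_vals, ref_idx = subseq_b.topk(min(k, subseq_b.size(0)), dim=0)
    assert torch.equal(values_per_batch[b], ref_vals)
    assert torch.equal(indices_per_batch[b], ref_idx)   # holds when there are no ties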