Multi-level sparse neighborhood attention
Overview
The multi-level sparse neighborhood attention operation lets each query point attend to the small neighborhood of nonzero points around its spatial position, with one neighborhood per feature level. This makes it a potentially useful alternative or complement to multi-scale deformable attention, which may sample from zero (empty) locations when applied to sparse tensors. Neighborhood attention, by contrast, always attends to every nonzero point within the given neighborhood sizes.
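As a minimal illustration of this difference, the sketch below models one feature level as a Python dict from (y, x) coordinates to features, where absent coordinates stand for zero points. The helper names `gather_nonzero_neighbors` and `sample_at_offsets` are hypothetical, and the deformable-style sampler is simplified to integer offsets (real deformable attention interpolates bilinearly at learned fractional offsets).

```python
def gather_nonzero_neighbors(sparse_map, center, size):
    """Return (coord, feature) for every nonzero point in a size x size window.

    sparse_map: dict mapping (y, x) -> feature; absent keys are zero points.
    Hypothetical helper for illustration only.
    """
    r = size // 2
    cy, cx = center
    found = []
    for dy in range(-r, r + 1):
        for dx in range(-r, r + 1):
            coord = (cy + dy, cx + dx)
            if coord in sparse_map:  # only nonzero points are ever returned
                found.append((coord, sparse_map[coord]))
    return found


def sample_at_offsets(sparse_map, center, offsets):
    """Deformable-style sampling at fixed integer offsets (simplified).

    Any offset that lands on an empty location yields a zero feature.
    """
    cy, cx = center
    return [sparse_map.get((cy + dy, cx + dx), 0.0) for dy, dx in offsets]


sparse_map = {(0, 0): 1.0, (1, 2): 2.0, (5, 5): 3.0}
neighbors = gather_nonzero_neighbors(sparse_map, center=(1, 1), size=3)
samples = sample_at_offsets(sparse_map, center=(1, 1), offsets=[(0, 0), (1, 1)])
print(neighbors)  # [((0, 0), 1.0), ((1, 2), 2.0)]
print(samples)    # [0.0, 0.0]
```

The neighborhood gather finds both nonzero points in the 3x3 window around (1, 1), while both fixed-offset samples land on empty locations and contribute only zeros.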
The neighborhood attention implementation uses a custom autograd operator that checkpoints the key and value projections of the neighborhood points, recomputing them during the backward pass instead of storing them, and calculates the backward pass manually. This checkpointing is essential for memory management, particularly when there are many query points, as in a DETR encoder or in a DETR decoder with many object queries.
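A rough estimate of what would be kept alive without checkpointing shows why it matters. The sketch below assumes, hypothetically, that each query materializes its own projected key and value vector for every slot of its square windows, in fp32, for a single attention layer. The real operator's storage layout may differ, and the count is an upper bound because only nonzero points are attended to.

```python
def keys_per_query(neighborhood_sizes):
    """Upper bound on neighborhood points per query: one size x size window per level."""
    return sum(k * k for k in neighborhood_sizes)


def stored_kv_bytes(num_queries, neighborhood_sizes, embed_dim, bytes_per_elem=4):
    """Bytes needed to keep every query's projected neighborhood keys and values
    for the backward pass (assumes each query holds its own gathered copy)."""
    # Factor 2: one projected key vector and one projected value vector per point.
    return num_queries * keys_per_query(neighborhood_sizes) * 2 * embed_dim * bytes_per_elem


# Hypothetical setting: 10,000 queries, sizes [3, 5, 7, 9], 256-dim fp32 embeddings.
n = keys_per_query([3, 5, 7, 9])
total = stored_kv_bytes(10_000, [3, 5, 7, 9], embed_dim=256)
print(n)                        # 164
print(total)                    # 3358720000
print(f"{total / 1e9:.2f} GB")  # 3.36 GB
```

Under these assumptions a single layer would hold about 3.4 GB of projected keys and values, which is the kind of cost that recomputing the projections in the backward pass avoids.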
SparseNeighborhoodAttentionBlock
Bases: Module
Sparse neighborhood attention block for multi-level feature maps.
This module performs sparse attention over local neighborhoods at multiple resolution levels. Each query attends only to keys in its spatial neighborhood, with a configurable neighborhood size at each resolution level. This enables hierarchical feature aggregation while keeping the per-query cost bounded by the neighborhood sizes rather than by the size of the feature maps.
The implementation uses rotary position encodings (RoPE) to incorporate spatial and level position information into the attention mechanism. Because RoPE rotates queries and keys so that their dot product depends on the difference between their positions, the attention scores reflect the relative spatial relationship between each query and key.
Parameters:
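The relative-position property of RoPE described for this block can be checked numerically. The sketch below is a plain-Python, multi-axis rotary encoding under stated assumptions: each position axis (here a hypothetical (level, y, x) triple) gets an equal contiguous block of dimensions, and frequencies follow the common `base ** (-2i / block)` schedule. This is not necessarily how the library splits dimensions or chooses frequencies, and `rope_rotate` is a hypothetical helper.

```python
import math


def rope_rotate(vec, position, base=10000.0):
    """Apply a multi-axis rotary encoding to vec (hypothetical layout).

    vec: list of floats whose length is divisible by 2 * len(position).
    position: coordinates, e.g. (level, y, x); each axis owns one equal,
    contiguous block of dimensions, and within a block the pair (2i, 2i+1)
    is rotated by position[axis] * base ** (-2i / block).
    """
    n_axes = len(position)
    if len(vec) % (2 * n_axes) != 0:
        raise ValueError("vector length must be divisible by 2 * number of axes")
    block = len(vec) // n_axes
    out = list(vec)
    for axis, pos in enumerate(position):
        start = axis * block
        for i in range(block // 2):
            theta = pos * base ** (-2 * i / block)
            c, s = math.cos(theta), math.sin(theta)
            a, b = vec[start + 2 * i], vec[start + 2 * i + 1]
            out[start + 2 * i] = a * c - b * s  # 2D rotation of the pair
            out[start + 2 * i + 1] = a * s + b * c
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


q = [0.1 * (i + 1) for i in range(12)]
k = [0.05 * (12 - i) for i in range(12)]
p_q, p_k = (0, 3.0, 4.0), (1, 5.0, 1.0)
shift = (0, 2.0, -3.0)  # the same translation applied to both positions

s1 = dot(rope_rotate(q, p_q), rope_rotate(k, p_k))
s2 = dot(rope_rotate(q, [a + b for a, b in zip(p_q, shift)]),
         rope_rotate(k, [a + b for a, b in zip(p_k, shift)]))
print(abs(s1 - s2) < 1e-9)  # True: the score depends only on the relative offset
```

Each rotation preserves vector norms, and the score between a rotated query and key depends only on the difference of their rotation angles, which is why translating both positions together leaves it unchanged.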
forward(query, query_spatial_positions, query_batch_offsets, stacked_feature_maps, level_spatial_shapes, background_embedding=None, query_level_indices=None, query_mask=None)
Forward pass of sparse neighborhood attention.
For each query, computes multi-head attention over keys in its spatial neighborhood at multiple resolution levels. The neighborhood size at each level is determined by the corresponding value in neighborhood_sizes.
Parameters:
Returns:
Raises:
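The per-query computation described for forward can be sketched in plain Python for a single query. This is a simplified illustration under stated assumptions: keys and values are taken as already projected (and, in the real block, already RoPE-encoded), the neighborhoods from all levels are concatenated into one list with a validity mask marking window slots that contain no nonzero point, and the learned input and output projections are omitted. `neighborhood_attention_single_query` is a hypothetical helper, not the library's API.

```python
import math


def neighborhood_attention_single_query(query, keys, values, valid, num_heads):
    """Multi-head attention of one query over its gathered neighborhood (sketch).

    query:        list[float] of length D
    keys, values: N vectors of length D (all levels' neighborhoods concatenated)
    valid:        N bools; False marks window slots with no nonzero point
    """
    d = len(query)
    if d % num_heads != 0:
        raise ValueError("embedding dim must be divisible by num_heads")
    head_dim = d // num_heads
    scale = 1.0 / math.sqrt(head_dim)
    out = []
    for h in range(num_heads):
        lo, hi = h * head_dim, (h + 1) * head_dim
        q = query[lo:hi]
        # Scaled dot-product scores; invalid slots get -inf so softmax ignores them.
        scores = [
            scale * sum(a * b for a, b in zip(q, key[lo:hi])) if ok else -math.inf
            for key, ok in zip(keys, valid)
        ]
        m = max(scores, default=-math.inf)
        if m == -math.inf:
            # No valid neighbor: this sketch arbitrarily outputs zeros.
            out.extend([0.0] * head_dim)
            continue
        exps = [math.exp(s - m) for s in scores]  # exp(-inf) == 0.0
        z = sum(exps)
        for j in range(lo, hi):
            out.append(sum(e * v[j] for e, v in zip(exps, values)) / z)
    return out


query = [1.0, 0.0, 0.0, 1.0]
keys = [[1.0, 0.0, 0.0, 1.0], [5.0, 5.0, 5.0, 5.0]]
values = [[1.0, 2.0, 3.0, 4.0], [100.0, 100.0, 100.0, 100.0]]
valid = [True, False]  # the second slot has no nonzero point
print(neighborhood_attention_single_query(query, keys, values, valid, num_heads=2))
# [1.0, 2.0, 3.0, 4.0]
```

Giving empty slots a score of negative infinity means they receive exactly zero weight after the softmax, so only nonzero points contribute, even though every query has a fixed-size window per level.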
reset_parameters()
Initializes/resets the weights of all submodules.
get_multilevel_neighborhoods(query_fullscale_spatial_positions, level_spatial_shapes, neighborhood_sizes=[3, 5, 7, 9])
Computes multi-resolution neighborhood indices for query positions.
Generates neighborhood indices at multiple resolution levels for each query position, with configurable neighborhood sizes for each level. This enables hierarchical feature aggregation by defining sampling regions around each query point at different scales.
Parameters:
Returns:
Raises:
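The index computation that get_multilevel_neighborhoods performs can be sketched as follows, assuming (hypothetically) that a full-scale position is mapped onto each level by floor scaling with the ratio of the level shape to the full-scale shape, and that neighborhood sizes are odd so that each window has a center. The function name `multilevel_neighborhoods_sketch` and its explicit `fullscale_shape` argument are illustrative assumptions; the library's function takes `level_spatial_shapes` directly, and its exact scaling and boundary handling may differ.

```python
def multilevel_neighborhoods_sketch(fullscale_pos, fullscale_shape, level_shapes, neighborhood_sizes):
    """Per level, return the window coordinates around a query and in-bounds flags.

    fullscale_pos:      (y, x) integer position at full resolution
    fullscale_shape:    (H, W) of the full-resolution grid
    level_shapes:       list of (h, w), one per level
    neighborhood_sizes: list of odd window sizes, one per level
    """
    if len(level_shapes) != len(neighborhood_sizes):
        raise ValueError("need one neighborhood size per level")
    y, x = fullscale_pos
    H, W = fullscale_shape
    result = []
    for (h, w), k in zip(level_shapes, neighborhood_sizes):
        if k % 2 == 0:
            raise ValueError("this sketch assumes odd neighborhood sizes")
        # Map the full-scale position onto this level's grid (assumed floor scaling).
        cy, cx = (y * h) // H, (x * w) // W
        r = k // 2
        coords, in_bounds = [], []
        for dy in range(-r, r + 1):  # row-major order, center at index k*k // 2
            for dx in range(-r, r + 1):
                ny, nx = cy + dy, cx + dx
                coords.append((ny, nx))
                in_bounds.append(0 <= ny < h and 0 <= nx < w)
        result.append((coords, in_bounds))
    return result


levels = multilevel_neighborhoods_sketch((0, 0), (8, 8), [(8, 8), (4, 4)], [3, 5])
for coords, in_bounds in levels:
    print(len(coords), sum(in_bounds))
# 9 4
# 25 9
```

Keeping out-of-bounds slots and flagging them gives every query a fixed-size index set per level, which suits batched gathering; a query at a corner keeps only the in-bounds part of each window.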