Multi-level sparse neighborhood attention

Overview

The multi-level sparse neighborhood attention operation allows query points to attend to small neighborhoods of nonzero points around their spatial position, one neighborhood per feature level. This is a potentially useful alternative or complement to multi-scale deformable attention, which may try to sample from zero (unspecified) points of a sparse tensor. The neighborhood attention operation, on the other hand, always attends to every nonzero point within the given neighborhood sizes.
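To make the contrast concrete, here is a minimal pure-Python sketch, assuming a 2D sparse map stored as a dict from integer coordinates to features (a stand-in for a real sparse tensor). The helper `neighborhood_nonzero_points` is hypothetical and not part of this library. It visits every specified point inside the window and never touches the implicit zeros.

```python
from itertools import product


def neighborhood_nonzero_points(sparse_map, center, size):
    """Return (coordinate, feature) pairs for every specified point inside a
    size x size window centered on `center` in a 2D sparse map.

    `sparse_map` maps (row, col) tuples to feature values; coordinates missing
    from the dict are the implicit zeros of the sparse tensor.
    """
    if size % 2 != 1:
        raise ValueError("neighborhood size must be odd")
    half = size // 2
    found = []
    for dr, dc in product(range(-half, half + 1), repeat=2):
        coord = (center[0] + dr, center[1] + dc)
        if coord in sparse_map:  # implicit zeros are skipped, never sampled
            found.append((coord, sparse_map[coord]))
    return found


sparse_map = {(0, 0): 1.0, (1, 2): 2.0, (5, 5): 3.0}
result = neighborhood_nonzero_points(sparse_map, center=(1, 1), size=3)
print(result)  # [((0, 0), 1.0), ((1, 2), 2.0)]
```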

The neighborhood attention implementation uses a custom autograd operator that checkpoints the key and value projections of the neighborhood points and computes the backward pass manually. This checkpointing is essential for memory management, particularly for operations with many query points, such as a DETR encoder or a DETR decoder with many object queries.
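A back-of-envelope calculation shows why checkpointing matters. The sketch below assumes a naive autograd implementation that would keep the gathered key and value vectors of every query alive until the backward pass; the concrete numbers (embed_dim = 256, fp32 storage, 10,000 queries) are illustrative assumptions, not measurements of this library.

```python
# Illustrative memory estimate for storing gathered keys and values per query.
# Assumptions: 2D positions, default neighborhood sizes, embed_dim = 256,
# fp32 storage, 10,000 queries. Attention weights and other activations are ignored.
neighborhood_sizes = [3, 5, 7, 9]
position_dim = 2
embed_dim = 256
bytes_per_element = 4  # fp32
n_queries = 10_000

keys_per_query = sum(s ** position_dim for s in neighborhood_sizes)  # 9 + 25 + 49 + 81
# One key vector and one value vector per neighborhood point.
bytes_per_query = 2 * keys_per_query * embed_dim * bytes_per_element
total_gib = n_queries * bytes_per_query / 2 ** 30

print(keys_per_query)          # 164
print(bytes_per_query)         # 335872
print(f"{total_gib:.2f} GiB")  # 3.13 GiB
```

Because this cost grows linearly with the number of queries, recomputing the projections in the backward pass instead of storing them is what keeps encoder-scale query counts tractable.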


SparseNeighborhoodAttentionBlock

Bases: Module

Sparse neighborhood attention block for multi-level feature maps.

This module performs sparse attention over local neighborhoods at multiple resolution levels. Each query attends only to keys in its spatial neighborhood, with configurable neighborhood sizes at different resolution levels. This enables hierarchical feature aggregation while maintaining computational efficiency through sparse attention.

The implementation uses rotary position encodings (RoPE) to incorporate spatial and level position information into the attention mechanism. This allows the attention to be aware of relative spatial relationships between queries and keys.
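The relative-position property of RoPE can be checked with a minimal 1D sketch. This is the standard rotary encoding applied to a single scalar position; it does not reproduce this block's multi-dimensional spatial + level scheme or its frequency grouping (rope_freq_group_pattern, rope_share_heads). The `base_theta` argument plays the role of rope_spatial_base_theta, and `rope_rotate` is a hypothetical helper.

```python
import math


def rope_rotate(x, position, base_theta=100.0):
    """Apply a 1D rotary position encoding to the even-length vector `x`.

    Pair k, i.e. (x[2k], x[2k+1]), is rotated by the angle
    position * base_theta ** (-2k / d), where d = len(x).
    """
    d = len(x)
    out = []
    for k in range(d // 2):
        angle = position * base_theta ** (-2 * k / d)
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[2 * k], x[2 * k + 1]
        out += [a * c - b * s, a * s + b * c]
    return out


def dot(u, v):
    return sum(p * q for p, q in zip(u, v))


q_vec = [0.3, -1.2, 0.8, 0.5]
k_vec = [1.0, 0.4, -0.7, 0.2]

# The same relative offset (2.0) at two different absolute positions.
s1 = dot(rope_rotate(q_vec, 5.0), rope_rotate(k_vec, 3.0))
s2 = dot(rope_rotate(q_vec, 12.0), rope_rotate(k_vec, 10.0))
print(math.isclose(s1, s2, abs_tol=1e-9))  # True
```

Since each pair is rotated by an angle proportional to position, the query-key dot product depends only on the difference of the two positions. A larger base theta shrinks the per-pair frequencies, so the encoding varies more slowly over large spatial extents.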

Parameters:
  • embed_dim (int) –

    Dimensionality of input and output embeddings.

  • n_heads (int) –

    Number of attention heads.

  • n_levels (int, default: 4 ) –

    Number of resolution levels. Default: 4

  • neighborhood_sizes (Union[Tensor, list[int]], default: [3, 5, 7, 9] ) –

    List of odd integers specifying the neighborhood size (window width) at each level. Default: [3, 5, 7, 9]

  • position_dim (int, default: 2 ) –

    Dimensionality of spatial positions. Default: 2 (for 2D positions).

  • dropout (float, default: 0.0 ) –

    Dropout probability for attention weights and output projection. Default: 0.0.

  • bias (bool, default: False ) –

    Whether to use bias in linear projections. Default: False.

  • norm_first (bool, default: True ) –

    Whether to apply layer normalization before attention. Default: True.

  • rope_spatial_base_theta (float, default: 100.0 ) –

    Base theta value for RoPE spatial dimensions. Larger values result in lower-frequency rotations, suitable for dimensions with greater spatial scale. Default: 100.0.

  • rope_level_base_theta (float, default: 10.0 ) –

    Base theta value for RoPE level dimension. Default: 10.0.

  • rope_share_heads (bool, default: False ) –

    Whether to share RoPE frequencies across attention heads. Default: False.

  • rope_freq_group_pattern (str, default: 'single' ) –

    Pattern to use for grouping RoPE frequencies. Options: "single", "partition", "closure". Default: "single".

  • rope_enforce_freq_groups_equal (bool, default: True ) –

    Whether to enforce equal division of frequency dimensions across frequency groups. Default: True.
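The constructor arguments above must be mutually consistent. The following hypothetical pre-flight check (`check_block_config` is not part of this library) sketches the constraints implied by this page: one odd neighborhood size per level. The divisibility of embed_dim by n_heads is an assumption based on standard multi-head attention, not something stated here.

```python
def check_block_config(embed_dim, n_heads, n_levels=4,
                       neighborhood_sizes=(3, 5, 7, 9), position_dim=2):
    """Hypothetical validation of SparseNeighborhoodAttentionBlock arguments.

    Returns the number of neighborhood keys each query attends to.
    """
    if embed_dim % n_heads != 0:
        # Assumption: standard multi-head attention splits embed_dim evenly.
        raise ValueError(f"embed_dim={embed_dim} is not divisible by n_heads={n_heads}")
    if len(neighborhood_sizes) != n_levels:
        raise ValueError(
            f"expected {n_levels} neighborhood sizes, got {len(neighborhood_sizes)}"
        )
    if any(s % 2 != 1 for s in neighborhood_sizes):
        raise ValueError(f"neighborhood sizes must be odd, got {list(neighborhood_sizes)}")
    return sum(s ** position_dim for s in neighborhood_sizes)


print(check_block_config(embed_dim=256, n_heads=8))  # 164
```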

forward(query, query_spatial_positions, query_batch_offsets, stacked_feature_maps, level_spatial_shapes, background_embedding=None, query_level_indices=None, query_mask=None)

Forward pass of sparse neighborhood attention.

For each query, computes multi-head attention over keys in its spatial neighborhood at multiple resolution levels. The neighborhood size at each level is determined by the corresponding value in neighborhood_sizes.
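For a single query and a single head, the per-query computation reduces to masked scaled dot-product attention over its gathered neighborhood keys. The sketch below is illustrative: `masked_attention` is a hypothetical helper, it ignores projections, RoPE, and multiple heads, and returning a zero update when every key is masked is a choice made for this sketch.

```python
import math


def masked_attention(query, keys, values, key_mask):
    """Single-head scaled dot-product attention for one query.

    key_mask[j] is True for keys to exclude, e.g. out-of-bounds or unspecified
    neighborhood points when no background embedding is given.
    """
    scale = 1.0 / math.sqrt(len(query))
    scores = [
        -math.inf if masked else scale * sum(q * k for q, k in zip(query, key))
        for key, masked in zip(keys, key_mask)
    ]
    value_dim = len(values[0])
    best = max(scores)
    if best == -math.inf:
        # Sketch choice: a query with no valid keys receives a zero update.
        return [0.0] * value_dim
    weights = [math.exp(s - best) for s in scores]  # exp(-inf) == 0.0
    total = sum(weights)
    return [
        sum(w * v[i] for w, v in zip(weights, values)) / total
        for i in range(value_dim)
    ]


query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
values = [[1.0, 0.0], [0.0, 1.0], [100.0, 100.0]]
key_mask = [False, False, True]  # the third key lies outside the feature map

out = masked_attention(query, keys, values, key_mask)
print(out)  # the masked key's large value never leaks into the output
```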

Parameters:
  • query (Tensor) –

    Query features of shape [n_queries, embed_dim].

  • query_spatial_positions (Tensor) –

    Spatial positions of queries, shape [n_queries, position_dim]. The positions must be floating-point values expressed in the coordinate range of the highest-resolution level in level_spatial_shapes. This function raises an error if the positions are integers. If you have integer positions (i.e., position indices), use prep_multilevel_positions to get full-resolution floating-point positions.

  • query_batch_offsets (Tensor) –

    Tensor of shape [batch_size+1] indicating where each batch starts in the queries.

  • stacked_feature_maps (Tensor) –

    Sparse tensor containing feature maps stacked across all levels, with total shape [batch, *spatial_dims, levels, embed_dim], where the last dimension is dense and the others are sparse.

  • level_spatial_shapes (Tensor) –

    Spatial dimensions of each level, shape [n_levels, position_dim]. Contains the height, width, etc. of feature maps at each resolution level.

  • background_embedding (Optional[Tensor], default: None ) –

    Optional tensor of shape [batch_size, n_levels, embed_dim] to serve as a background embedding. If given, then neighborhood indices that are not specified in stacked_feature_maps will be given the corresponding background embedding for that batch and level. If not given, then these keys will be masked out from the queries.

  • query_level_indices (Optional[Tensor], default: None ) –

    Level indices of each query, shape [n_queries]. If None, every query is assumed to come from the level with the largest spatial shape. This value should be specified in the encoder, where queries are tokens at various levels, but may be omitted in the decoder, where the queries are object queries treated as lying at the full-scale level.

  • query_mask (Optional[Tensor], default: None ) –

    Optional boolean tensor of shape [n_queries] marking queries that should not participate in the operation. Each query whose entry is True is masked out from all keys, so its vector passes through the attention + residual operation unmodified.
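Two of these arguments follow simple conventions that can be sketched in pure Python. Assuming query_batch_offsets uses the usual CSR-style layout implied by its [batch_size+1] shape (first entry 0, last entry n_queries), it can be built from per-batch query counts. The default query level is the level with the largest spatial shape, interpreted here as the one with the most elements. Both helpers are hypothetical.

```python
import math
from itertools import accumulate


def batch_offsets_from_counts(queries_per_batch):
    """Build a [batch_size + 1] offsets list from per-batch query counts.

    offsets[b] is the index of the first query of batch element b, and
    offsets[-1] equals the total number of queries (assumed convention).
    """
    return [0] + list(accumulate(queries_per_batch))


def default_query_level(level_spatial_shapes):
    """Index of the level whose spatial shape has the most elements."""
    sizes = [math.prod(shape) for shape in level_spatial_shapes]
    return sizes.index(max(sizes))


offsets = batch_offsets_from_counts([3, 0, 5])
print(offsets)  # [0, 3, 3, 8]

queries = [f"q{i}" for i in range(offsets[-1])]
# The queries of batch element b occupy rows offsets[b]:offsets[b + 1].
print(queries[offsets[2]:offsets[3]])  # ['q3', 'q4', 'q5', 'q6', 'q7']

print(default_query_level([[64, 64], [32, 32], [16, 16], [8, 8]]))  # 0
```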

Returns:
  • Tensor( Tensor ) –

    Output embeddings after neighborhood attention, shape [n_queries, embed_dim].

Raises:
  • ValueError

    If input tensors don't have expected shapes, or if query_spatial_positions is an integer tensor.

reset_parameters()

Initializes/resets the weights of all submodules.


get_multilevel_neighborhoods(query_fullscale_spatial_positions, level_spatial_shapes, neighborhood_sizes=[3, 5, 7, 9])

Computes multi-resolution neighborhood indices for query positions.

Generates neighborhood indices at multiple resolution levels for each query position, with configurable neighborhood sizes for each level. This enables hierarchical feature aggregation by defining sampling regions around each query point at different scales.
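The procedure can be sketched in pure Python under stated assumptions: the full-scale shape is taken to be the largest level shape, a query's center at each level is floor(position * level_shape / fullscale_shape), and an out-of-bounds index has all of its coordinates replaced by -1. These mapping details are assumptions for illustration, and `multilevel_neighborhoods` is a hypothetical stand-in, not the library function.

```python
import math
from itertools import product


def multilevel_neighborhoods(query_positions, level_spatial_shapes,
                             neighborhood_sizes=(3, 5, 7, 9)):
    """Pure-Python sketch of multi-level neighborhood index generation.

    Assumptions: the full-scale shape is the largest level shape, and a query's
    center at each level is floor(position * level_shape / fullscale_shape).
    """
    if len(neighborhood_sizes) != len(level_spatial_shapes):
        raise ValueError("need one neighborhood size per level")
    if any(s % 2 != 1 for s in neighborhood_sizes):
        raise ValueError("neighborhood sizes must be odd")
    fullscale = max(level_spatial_shapes, key=math.prod)
    position_dim = len(fullscale)

    # level_indices is identical for every query, so build it once.
    level_indices = []
    for level, size in enumerate(neighborhood_sizes):
        level_indices += [level] * size ** position_dim

    indices, oob_mask = [], []
    for pos in query_positions:
        q_idx, q_oob = [], []
        for shape, size in zip(level_spatial_shapes, neighborhood_sizes):
            center = [math.floor(p * n / f) for p, n, f in zip(pos, shape, fullscale)]
            half = size // 2
            for offset in product(range(-half, half + 1), repeat=position_dim):
                idx = [c + o for c, o in zip(center, offset)]
                out = any(i < 0 or i >= n for i, n in zip(idx, shape))
                # Out-of-bounds indices are replaced with the fill value -1.
                q_idx.append([-1] * position_dim if out else idx)
                q_oob.append(out)
        indices.append(q_idx)
        oob_mask.append(q_oob)
    return indices, oob_mask, level_indices


shapes = [[8, 8], [4, 4]]
idx, oob, lvl = multilevel_neighborhoods([[0.5, 7.5]], shapes, neighborhood_sizes=[3, 5])
print(len(idx[0]), len(lvl), sum(oob[0]))  # 34 34 21
```

A real implementation would vectorize this over queries with tensor operations; the loops here only make the indexing explicit.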

Parameters:
  • query_fullscale_spatial_positions (Tensor) –

    Query positions of shape [n_queries, position_dim], where each row contains the N-D position of a query point at the full-scale resolution.

  • level_spatial_shapes (Tensor) –

    Tensor of shape [num_levels, position_dim] specifying the spatial dimensions of each resolution level.

  • neighborhood_sizes (Union[Tensor, list[int]], default: [3, 5, 7, 9] ) –

    List or tensor of odd integers specifying the neighborhood size (window width) at each level. Default: [3, 5, 7, 9].

Returns:
  • multilevel_neighborhood_indices( Tensor ) –

    Tensor of shape [n_queries, sum(neighborhood_sizes^position_dim), position_dim] containing the spatial indices of all neighborhood points for each query across all levels.

  • out_of_bounds_mask( Tensor ) –

    Boolean tensor of shape [n_queries, sum(neighborhood_sizes^position_dim)] that is True at locations in multilevel_neighborhood_indices that are out of bounds, i.e., negative or >= the spatial shape for that level. The out-of-bounds entries of multilevel_neighborhood_indices are filled with the mask value -1.

  • level_indices( Tensor ) –

    Tensor of shape [sum(neighborhood_sizes^position_dim)] mapping each neighborhood position to its corresponding resolution level.

Raises:
  • ValueError

    If input tensors don't have the expected shape or dimensions, or if any neighborhood size is not an odd number.