Sparse multi-scale deformable attention

Overview

This module implements a version of Multi-Scale Deformable Attention (MSDeformAttention), introduced in Deformable DETR, adapted to operate on sparse tensors.


SparseMSDeformableAttentionBlock

Bases: Module

A standard transformer block using Sparse Multi-Scale Deformable Attention.

This module encapsulates the SparseMSDeformableAttention layer within a typical transformer block structure. It includes a query input projection, the attention mechanism itself, an output projection with dropout, a residual connection, and layer normalization. The layer normalization can be applied either before (pre-norm) or after (post-norm) the main block operations.

This block is designed to be a plug-and-play component in a larger transformer architecture that operates on sparse, multi-scale feature maps, such as the encoder or decoder of a Deformable DETR-like model.

The current version of this module only supports spatially-2D data.

Parameters:
  • embed_dim (int) –

    The embedding dimension for the queries and features.

  • n_heads (int) –

    The number of attention heads.

  • n_levels (int) –

    The number of feature levels to sample from.

  • n_points (int) –

    The number of sampling points per head per level.

  • dropout (float, default: 0.0 ) –

    Dropout probability for the output projection. Defaults to 0.0.

  • bias (bool, default: False ) –

    Whether to include bias terms in the input and output projection layers. Defaults to False.

  • norm_first (bool, default: True ) –

    If True, applies layer normalization before the attention and projection (pre-norm). If False, applies it after the residual connection (post-norm). Defaults to True.
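The effect of norm_first can be sketched with a structural stand-in (not the actual implementation; the attention mechanism is replaced by a plain linear layer here purely to show where the normalization, dropout, and residual connection sit):

```python
import torch
from torch import nn

class PreOrPostNormBlock(nn.Module):
    """Structural sketch of the norm_first behavior described above.

    The real block wraps SparseMSDeformableAttention; a linear layer
    stands in for it so the sketch is self-contained.
    """

    def __init__(self, embed_dim: int, dropout: float = 0.0, norm_first: bool = True):
        super().__init__()
        self.attn = nn.Linear(embed_dim, embed_dim)  # stand-in for the attention layer
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(embed_dim)
        self.norm_first = norm_first

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        if self.norm_first:
            # Pre-norm: normalize, attend, project, then add the residual.
            return query + self.dropout(self.out_proj(self.attn(self.norm(query))))
        # Post-norm: attend, project, add the residual, then normalize.
        return self.norm(query + self.dropout(self.out_proj(self.attn(query))))
```

With norm_first=True the residual path carries the raw input, which tends to stabilize training in deep stacks; with norm_first=False normalization is applied to the summed output.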

forward(query, query_spatial_positions, query_batch_offsets, stacked_feature_maps, level_spatial_shapes, background_embedding=None, query_level_indices=None)

Forward pass for the SparseMSDeformableAttentionBlock.

Parameters:
  • query (Tensor) –

    Batch-flattened query tensor of shape [n_query, embed_dim].

  • query_spatial_positions (Tensor) –

    Spatial positions of queries, shape [n_queries, 2]. The positions must be floating-point values scaled to the feature level in which each query resides.

  • query_batch_offsets (Tensor) –

    Tensor of shape [batch_size+1] indicating the start and end indices for each batch item in the flattened query.

  • stacked_feature_maps (Tensor) –

    A sparse tensor containing feature maps from all levels, with shape [batch, height, width, levels, embed_dim]. The last dimension is dense, others are sparse.

  • level_spatial_shapes (Tensor) –

    Spatial dimensions (height, width) of each feature level, shape [n_levels, 2].

  • background_embedding (Optional[Tensor], default: None ) –

    An embedding to use for sampling points that fall in unspecified regions of the sparse feature maps. Shape [batch, n_levels, embed_dim].

  • query_level_indices (Optional[Tensor], default: None ) –

    The level index for each query, shape [n_queries]. If None, queries are assumed to be at the largest feature level.

Returns:
  • Tensor –

    The output tensor after the attention block, with the same shape as the input query, [n_query, embed_dim].

reset_parameters()

Resets the parameters of all submodules.


SparseMSDeformableAttention

Bases: Module

An nn.Module for Multi-Scale Deformable Attention on sparse feature maps.

This module implements the attention mechanism described in "Deformable DETR". Instead of attending to all features in a dense feature map, each query learns to sample a small, fixed number of points (n_points) from multiple feature levels (n_levels). The locations of these sampling points are predicted as offsets from the query's reference position.

This implementation is adapted to work with torch.sparse_coo_tensors as the input feature maps. It uses a custom bilinear interpolation function to efficiently sample values from the sparse feature maps at the predicted locations.

The current version of this module only supports spatially-2D data.

The module contains learnable parameters for:
  • A value projection (value_proj) applied to the input feature maps.
  • A linear layer (sampling_offsets) to predict the 2D offsets for each sampling point.
  • A linear layer (attention_weights) to predict the weight of each sampled point.
  • A final output projection (output_proj).
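A sketch of plausible shapes for these submodules, following the Deformable DETR formulation (an assumption; the actual layer definitions may differ):

```python
import torch
from torch import nn

embed_dim, n_heads, n_levels, n_points = 256, 8, 4, 4

value_proj = nn.Linear(embed_dim, embed_dim)
# One 2-D offset per head, per level, per point.
sampling_offsets = nn.Linear(embed_dim, n_heads * n_levels * n_points * 2)
# One scalar weight per head, per level, per point.
attention_weights = nn.Linear(embed_dim, n_heads * n_levels * n_points)
output_proj = nn.Linear(embed_dim, embed_dim)

q = torch.randn(10, embed_dim)
offsets = sampling_offsets(q).view(10, n_heads, n_levels, n_points, 2)
# In Deformable DETR the weights are softmax-normalized jointly over
# levels and points within each head.
weights = attention_weights(q).view(10, n_heads, n_levels * n_points).softmax(-1)
```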

Parameters:
  • embed_dim (int) –

    The embedding dimension of the input and output features.

  • n_heads (int) –

    The number of attention heads.

  • n_levels (int, default: 4 ) –

    The number of feature levels to sample from.

  • n_points (int, default: 4 ) –

    The number of sampling points per head per level.

forward(query, query_spatial_positions, query_batch_offsets, stacked_feature_maps, level_spatial_shapes, query_level_indices=None, background_embedding=None)

Forward function for SparseMSDeformableAttention.

Parameters:
  • query (Tensor) –

Batch-flattened query tensor of shape [n_query, embed_dim].

  • query_batch_offsets (Tensor) –

Tensor of shape [batch_size+1] whose item i is the start index of batch item i in the flattened query tensor and whose item i+1 is its end index.

  • query_spatial_positions (Tensor) –

Spatial positions of queries, shape [n_queries, position_dim]. The positions must be floating-point values scaled to the spatial shape of the feature level in which each query resides, as specified by query_level_indices.

  • stacked_feature_maps (Tensor) –

    Sparse tensor containing feature maps stacked across all levels, with total shape [batch, height, width, levels, embed_dim], where the last dimension is dense and the others are sparse.

  • level_spatial_shapes (Tensor) –

    Spatial dimensions of each level, shape [n_levels, 2]. Contains the height and width of feature maps at each resolution level.

  • query_level_indices (Optional[Tensor], default: None ) –

Level indices of each query, shape [n_queries]. If None, every query is assumed to be at the level with the largest spatial shape. This value should be specified in the encoder, where queries are tokens at various levels, but may be omitted in the decoder, where the object queries are taken to be at the full-scale level.

  • background_embedding (Optional[Tensor], default: None ) –

Tensor of shape (batch, n_levels, embed_dim) used as the interpolant for sampling points that fall in unspecified regions of stacked_feature_maps. If not given, a zero vector is used instead.

Returns:
  • Tensor –

    Output embeddings after sparse deformable attention, shape [n_queries, embed_dim].
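Internally, the output is a weighted sum of the values sampled at the predicted locations. The aggregation step can be sketched with stand-in tensors, under the assumption (taken from Deformable DETR) that attention weights are softmax-normalized jointly over levels and points:

```python
import torch

n_q, n_heads, n_levels, n_points, head_dim = 5, 8, 4, 4, 32

# Stand-ins for intermediate results of the forward pass: per-point
# sampled values and their softmax-normalized attention weights.
sampled = torch.randn(n_q, n_heads, n_levels, n_points, head_dim)
weights = torch.rand(n_q, n_heads, n_levels * n_points).softmax(-1)
weights = weights.view(n_q, n_heads, n_levels, n_points)

# Weighted sum over levels and points, then heads concatenated back
# to embed_dim = n_heads * head_dim.
out = (sampled * weights.unsqueeze(-1)).sum(dim=(2, 3)).flatten(1)
```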

reset_parameters()

Resets the parameters of all submodules.


Utilities

sparse_split_heads(sparse_tensor, n_heads)

Splits a sparse tensor into multiple heads.

Parameters:
  • sparse_tensor (Tensor) –

    The input sparse tensor.

  • n_heads (int) –

    The number of heads to split into.

Returns:
  • Tensor –

    The split sparse tensor with shape (*sparse_tensor.shape[:-1], n_heads, head_dim).
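A hypothetical re-implementation illustrating the transformation (an assumption; names and details may differ from the actual utility):

```python
import torch

def sparse_split_heads_sketch(sparse_tensor: torch.Tensor, n_heads: int) -> torch.Tensor:
    """Split the dense last dimension of a COO sparse tensor into
    (n_heads, head_dim) by reshaping its values; sparse dims are unchanged."""
    st = sparse_tensor.coalesce()
    embed_dim = st.shape[-1]
    head_dim = embed_dim // n_heads
    values = st.values().reshape(-1, n_heads, head_dim)
    return torch.sparse_coo_tensor(
        st.indices(), values, size=(*st.shape[:-1], n_heads, head_dim)
    )
```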

multilevel_sparse_bilinear_grid_sample(sparse_tensor, spatial_positions, batch_indices, level_spatial_shapes, level_indices=None, head_indices=None, background_embedding=None)

Bilinearly samples into a 2D sparse tensor. Similar to F.grid_sample with align_corners=False, except that the sampled tensor is expected to be sparse, the points are not arranged in a grid, and the sampling points are given in absolute coordinates rather than normalized to [-1, 1].

Note that this function uses a coordinate system that places coordinate (i, j) at the upper left corner of pixel [i, j]. The center of the pixel [i, j] is thus at coordinate (i+0.5, j+0.5). Interpolation is done from the 4 closest pixel centers to each sampling point in spatial_positions.

The batch size (number of images), number of feature levels, and number of attention heads are inferred from the shape of sparse_tensor.
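The pixel-center convention above can be illustrated by computing the four neighboring pixels and their bilinear weights for a single sampling point:

```python
import torch

# Under the convention above, pixel [i, j] has its center at (i + 0.5, j + 0.5).
p = torch.tensor([3.2, 5.9])          # sampling point in (i, j) coordinates
base = (p - 0.5).floor()              # upper-left of the 4 nearest pixel centers
frac = (p - 0.5) - base               # fractional offset toward the lower-right
neighbors = (base + torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])).long()
weights = torch.stack([
    (1 - frac[0]) * (1 - frac[1]),    # weight for pixel [2, 5]
    (1 - frac[0]) * frac[1],          # weight for pixel [2, 6]
    frac[0] * (1 - frac[1]),          # weight for pixel [3, 5]
    frac[0] * frac[1],                # weight for pixel [3, 6]
])
# The weights sum to 1; a point exactly on a pixel center puts all
# of its weight on that single pixel.
```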

Parameters:
  • sparse_tensor (Tensor) –

    torch.sparse.sparse_coo_tensor with shape (batch_size, height, width, n_levels, n_heads, head_dim), with the last dimension being dense and other dimensions sparse.

  • spatial_positions (Tensor) –

    Sampling point coordinates, with shape (N, ..., L, H, 2), with the last dimension in order (i, j), as in sparse_tensor, N being a batch dimension, L and H being level and head dimensions, respectively, and ... being additional optional batch dimensions.

  • batch_indices (Tensor) –

    Tensor of shape (N, ...) with the index of the batch element (image index) for each point in spatial_positions

  • level_spatial_shapes (Tensor) –

    (n_levels, 2) tensor with the (height, width) of each level's feature map

  • level_indices (Optional[Tensor], default: None ) –

    Tensor of shape (L) with the index of the level for each point in spatial_positions. If None, it defaults to torch.arange(n_levels).

  • head_indices (Optional[Tensor], default: None ) –

    Tensor of shape (H) with the index of the head to be sampled for each point in spatial_positions. If None, it defaults to torch.arange(n_heads).

  • background_embedding (Optional[Tensor], default: None ) –

    Tensor of shape (batch, n_levels, n_heads, head_dim) that should be used as an interpolant for points that are not specified in sparse_tensor. If not given, a 0 vector will be used instead.

Returns:
  • Tensor –

    Bilinear interpolated tensor of shape (N, ..., L, H, head_dim).