Sparse multi-scale deformable attention¶
Overview¶
This implements a version of Multi-Scale Deformable Attention (MSDeformAttention) adapted for sparse tensors.
SparseMSDeformableAttentionBlock¶
Bases: Module
A standard transformer block using Sparse Multi-Scale Deformable Attention.
This module encapsulates the SparseMSDeformableAttention layer within a typical transformer block structure. It includes a query input projection, the attention mechanism itself, an output projection with dropout, a residual connection, and layer normalization. The layer normalization can be applied either before (pre-norm) or after (post-norm) the main block operations.
This block is designed to be a plug-and-play component in a larger transformer architecture that operates on sparse, multi-scale feature maps, such as the encoder or decoder of a Deformable DETR-like model.
The current version of this module only supports spatially-2D data.
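To make the wiring concrete, below is a minimal sketch of a pre-/post-norm block in this style. It is an illustration, not this module's actual code: the attention layer is stubbed out with nn.Identity, and the constructor arguments (d_model, dropout, norm_first) are assumed names.

```python
import torch
import torch.nn as nn

class BlockWiringSketch(nn.Module):
    """Illustrative pre-/post-norm transformer block wiring (not the real API)."""

    def __init__(self, d_model: int, dropout: float = 0.1, norm_first: bool = True):
        super().__init__()
        self.in_proj = nn.Linear(d_model, d_model)   # query input projection
        self.attn = nn.Identity()                    # stand-in for SparseMSDeformableAttention
        self.out_proj = nn.Linear(d_model, d_model)  # output projection
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)
        self.norm_first = norm_first

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        residual = query
        if self.norm_first:              # pre-norm: normalize before the block body
            query = self.norm(query)
        x = self.out_proj(self.attn(self.in_proj(query)))
        x = residual + self.dropout(x)   # residual connection
        if not self.norm_first:          # post-norm: normalize after the residual
            x = self.norm(x)
        return x
```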
forward(query, query_spatial_positions, query_batch_offsets, stacked_feature_maps, level_spatial_shapes, background_embedding=None, query_level_indices=None)¶
Forward pass for the SparseMSDeformableAttentionBlock.
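As a usage illustration only, the following sketch shows plausible input shapes, assuming that queries from all batch elements are concatenated along the first dimension and that stacked_feature_maps is a sparse COO tensor with sparse dims (batch, level, i, j) and a dense feature dim. The exact layout and all sizes here are assumptions, not a documented contract.

```python
import torch

n_queries, d_model, n_levels, batch = 10, 256, 3, 2

query = torch.randn(n_queries, d_model)
query_spatial_positions = torch.rand(n_queries, 2) * 16  # absolute (i, j) coordinates
query_batch_offsets = torch.tensor([0, 6, 10])           # queries 0-5 -> image 0, 6-9 -> image 1
level_spatial_shapes = torch.tensor([[64, 64], [32, 32], [16, 16]])

# Assumed layout: sparse dims (batch, level, i, j), dense feature dim d_model.
indices = torch.stack([
    torch.randint(0, batch, (500,)),
    torch.randint(0, n_levels, (500,)),
    torch.randint(0, 16, (500,)),  # i, kept within the smallest level
    torch.randint(0, 16, (500,)),  # j
])
stacked_feature_maps = torch.sparse_coo_tensor(
    indices, torch.randn(500, d_model), size=(batch, n_levels, 64, 64, d_model)
).coalesce()

# block = SparseMSDeformableAttentionBlock(...)  # constructor args omitted
# out = block(query, query_spatial_positions, query_batch_offsets,
#             stacked_feature_maps, level_spatial_shapes)
```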
reset_parameters()¶
Resets the parameters of all submodules.
SparseMSDeformableAttention¶
Bases: Module
An nn.Module for Multi-Scale Deformable Attention on sparse feature maps.
This module implements the attention mechanism described in "Deformable DETR". Instead of attending to all features in a dense feature map, each query learns to sample a small, fixed number of points (n_points) from multiple feature levels (n_levels). The locations of these sampling points are predicted as offsets from the query's reference position.
This implementation is adapted to work with torch.sparse_coo_tensor objects as the input feature maps. It uses a custom bilinear interpolation function to efficiently sample values from the sparse feature maps at the predicted locations.
The current version of this module only supports spatially-2D data.
The module contains learnable parameters for (see the sketch after this list):
- A value projection (value_proj) applied to the input feature maps.
- A linear layer (sampling_offsets) to predict the 2D offsets for each sampling point.
- A linear layer (attention_weights) to predict the weight of each sampled point.
- A final output projection (output_proj).
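A minimal sketch of how these four submodules are typically constructed in the Deformable DETR formulation; the dimension names (d_model, n_heads, n_levels, n_points) are assumptions about this module's hyperparameters.

```python
import torch.nn as nn

d_model, n_heads, n_levels, n_points = 256, 8, 4, 4

value_proj = nn.Linear(d_model, d_model)
# One 2D offset per (head, level, point) triple.
sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
# One scalar weight per sampled point; in Deformable DETR these are
# normalized with a softmax over all levels and points of a head.
attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
output_proj = nn.Linear(d_model, d_model)
```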
forward(query, query_spatial_positions, query_batch_offsets, stacked_feature_maps, level_spatial_shapes, query_level_indices=None, background_embedding=None)¶
Forward function for SparseMSDeformableAttention.
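The following simplified sketch shows the core computation for a single query, ignoring the head and level structure for clarity; it is a conceptual restatement of the mechanism, not this module's implementation, and the function names are hypothetical.

```python
import torch

def deformable_attend_sketch(query, reference_xy, sampling_offsets, attention_weights, sample_fn):
    """query: (d_model,); reference_xy: (2,) absolute position.

    sampling_offsets / attention_weights: the learned linear layers;
    sample_fn(xy) -> (d_model,) bilinearly sampled feature at xy.
    """
    offsets = sampling_offsets(query).view(-1, 2)      # (n_samples, 2)
    weights = attention_weights(query).softmax(dim=0)  # (n_samples,)
    locations = reference_xy + offsets                 # sample around the reference point
    sampled = torch.stack([sample_fn(xy) for xy in locations])
    return (weights.unsqueeze(-1) * sampled).sum(dim=0)  # weighted sum of samples
```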
reset_parameters()¶
Utilities¶
sparse_split_heads(sparse_tensor, n_heads)¶
Splits a sparse tensor into multiple heads.
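As an illustration of what head-splitting on a hybrid sparse tensor can look like, here is a sketch that reshapes the trailing dense feature dimension into (n_heads, head_dim); this is an assumption about the behavior, not the function's actual implementation.

```python
import torch

def split_heads_sketch(sparse_tensor: torch.Tensor, n_heads: int) -> torch.Tensor:
    """Split the trailing dense feature dim of a hybrid sparse COO
    tensor into (n_heads, head_dim) dense dims. A sketch only."""
    st = sparse_tensor.coalesce()
    values = st.values()                  # (nnz, d_model)
    nnz, d_model = values.shape
    assert d_model % n_heads == 0
    new_values = values.view(nnz, n_heads, d_model // n_heads)
    return torch.sparse_coo_tensor(
        st.indices(), new_values,
        size=(*st.shape[:-1], n_heads, d_model // n_heads),
    ).coalesce()
```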
multilevel_sparse_bilinear_grid_sample(sparse_tensor, spatial_positions, batch_indices, level_spatial_shapes, level_indices=None, head_indices=None, background_embedding=None)¶
Bilinearly samples into a 2D sparse tensor. Similar to F.grid_sample with align_corners=False, except that the sampled tensor is expected to be sparse, the points are not arranged in a grid, and the sampling points are given in absolute coordinates rather than normalized to [-1, 1].
Note that this function uses a coordinate system that places coordinate (i, j) at the upper left corner of pixel [i, j]. The center of pixel [i, j] is thus at coordinate (i+0.5, j+0.5). Interpolation is done from the 4 closest pixel centers to each sampling point in spatial_positions.
The batch size (number of images), number of feature levels, and number of attention heads are inferred from the shape of sparse_tensor.
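To make the convention concrete, here is a small sketch that computes the four interpolation corners and weights for one sampling point under this coordinate system; it illustrates the convention only, ignores border and background_embedding handling, and the function name is hypothetical.

```python
import math

def bilinear_corners_and_weights(i: float, j: float):
    """Corner pixel indices and bilinear weights for a point (i, j),
    with pixel [i, j]'s center located at (i + 0.5, j + 0.5)."""
    ci, cj = i - 0.5, j - 0.5            # shift so pixel centers are integers
    i0, j0 = math.floor(ci), math.floor(cj)
    di, dj = ci - i0, cj - j0            # fractional distances
    corners = [(i0, j0), (i0, j0 + 1), (i0 + 1, j0), (i0 + 1, j0 + 1)]
    weights = [(1 - di) * (1 - dj), (1 - di) * dj, di * (1 - dj), di * dj]
    return corners, weights

# A point exactly at a pixel center takes that pixel's value with weight 1:
corners, weights = bilinear_corners_and_weights(2.5, 3.5)
assert corners[0] == (2, 3) and weights[0] == 1.0
```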