Multi-level sparse self-attention

Overview

This self-attention implementation is intended for use with multi-level feature maps stored as torch.sparse_coo_tensor tensors. It uses RoPEEncodingND from nd-rotary-encodings to encode the spatial position and feature level of every input point.
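
For context, here is an illustrative sketch of how per-level sparse feature maps could be flattened into the stacked-token format that the block consumes. The helper name, the single-batch handling, and the assumption that each level is a hybrid sparse COO tensor (two sparse spatial dims, one dense embedding dim) are all assumptions, not part of the library:

```python
import torch

def flatten_sparse_levels(level_maps):
    """Illustrative helper (not part of the library): flatten per-level hybrid
    sparse COO feature maps -- each of shape (H_l, W_l, embed_dim) with two
    sparse spatial dims and one dense embedding dim, for a single batch
    element -- into the stacked-token format expected by the block."""
    tokens, positions, levels, shapes = [], [], [], []
    for lvl, sp in enumerate(level_maps):
        sp = sp.coalesce()
        idx = sp.indices()               # (2, nnz) grid coordinates (i, j)
        vals = sp.values()               # (nnz, embed_dim) token embeddings
        tokens.append(vals)
        positions.append(idx.T.float())  # unnormalized, level-local coordinates
        levels.append(torch.full((vals.shape[0],), lvl, dtype=torch.long))
        shapes.append(torch.tensor(sp.shape[:2]))
    x = torch.cat(tokens)                        # (stacked_sequence_length, embed_dim)
    spatial_positions = torch.cat(positions)     # (stacked_sequence_length, 2)
    level_indices = torch.cat(levels)            # (stacked_sequence_length,)
    level_spatial_shapes = torch.stack(shapes)   # (num_levels, 2)
    return x, spatial_positions, level_indices, level_spatial_shapes
```

For a batched input, the per-image results would be concatenated along the token dimension, with batch_offsets recording where each image's tokens begin in the stacked sequence.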


MultilevelSelfAttentionBlockWithRoPE

Bases: Module

Self-attention block for multi-level feature maps with rotary position encoding.

This module applies self-attention across tokens from multiple resolution levels, using Rotary Position Encodings (RoPE) to encode both the spatial positions and the resolution levels of the tokens.

This module is meant as a plug-and-play standard multi-head attention Transformer block, encapsulating the input projection, RoPE encoding, attention operation, output projection with optional dropout, and layer normalization in either a pre-norm or post-norm configuration.
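
As a rough sketch of the two normalization layouts selected by norm_first (this is the standard pre-norm/post-norm residual pattern, not the module's actual internals):

```python
import torch.nn.functional as F

# Conceptual sketch only: standard pre-norm vs. post-norm residual attention.
def residual_attention(x, attn, norm, dropout_p, norm_first=True):
    if norm_first:
        # Pre-norm: normalize first, attend, then add the residual.
        return x + F.dropout(attn(norm(x)), p=dropout_p)
    # Post-norm: attend, add the residual, then normalize the sum.
    return norm(x + F.dropout(attn(x), p=dropout_p))
```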

Parameters:
  • embed_dim (int) –

    Dimensionality of input and output embeddings.

  • n_heads (int) –

    Number of attention heads.

  • position_dim (int, default: 2 ) –

    Dimensionality of spatial positions. Default: 2 (for 2D positions).

  • dropout (float, default: 0.0 ) –

    Dropout probability for attention weights and output projection. Default: 0.0.

  • bias (bool, default: False ) –

    Whether to use bias in linear layers. Default: False.

  • norm_first (bool, default: True ) –

    Whether to apply layer normalization before attention. Default: True.

  • rope_spatial_base_theta (float, default: 100.0 ) –

    Base theta value for RoPE spatial dimensions. Default: 100.0.

  • rope_level_base_theta (float, default: 10.0 ) –

    Base theta value for RoPE level dimension. Default: 10.0.

  • rope_share_heads (bool, default: False ) –

    Whether to share RoPE frequencies across attention heads. Default: False.

  • rope_freq_group_pattern (str, default: 'single' ) –

    Pattern to use for grouping RoPE frequencies. Options: "single", "partition", "closure". Default: "single".

  • rope_enforce_freq_groups_equal (bool, default: True ) –

    Passed to the RoPE encoding module. Determines whether the encoding raises an error if the selected rope_freq_group_pattern produces a number of frequency groups that does not evenly divide the RoPE encoding dimensions. Default: True.
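
A minimal construction sketch using the parameters above (the import path is an assumption; adjust it to wherever the class is exposed in your installation):

```python
# Hypothetical import path -- substitute the actual module location.
from my_package.layers import MultilevelSelfAttentionBlockWithRoPE

block = MultilevelSelfAttentionBlockWithRoPE(
    embed_dim=256,
    n_heads=8,
    position_dim=2,                       # 2D feature maps
    dropout=0.1,
    bias=False,
    norm_first=True,                      # pre-norm configuration
    rope_spatial_base_theta=100.0,
    rope_level_base_theta=10.0,
    rope_share_heads=False,
    rope_freq_group_pattern="single",
    rope_enforce_freq_groups_equal=True,
)
```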

forward(x, spatial_positions, level_indices, level_spatial_shapes, batch_offsets, attn_mask=None)

Forward pass of multi-level self-attention with RoPE.

Parameters:
  • x (Tensor) –

    Input embeddings of shape (stacked_sequence_length, embed_dim). Contains embeddings from all batches concatenated together.

  • spatial_positions (Tensor) –

    Spatial positions of each token, shape (stacked_sequence_length, position_dim). These are expected to be in the original coordinate space of their respective levels, NOT normalized to the [0, 1] range.

  • level_indices (Tensor) –

    Level index for each token, shape (stacked_sequence_length, ).

  • level_spatial_shapes (Tensor) –

    Spatial dimensions of each level, shape (num_levels, position_dim). Contains the spatial extent of the feature map at each resolution level (e.g., height and width for 2D).

  • batch_offsets (Tensor) –

    Tensor of shape (batch_size+1, ) indicating where each batch element's tokens start in the stacked sequence.

  • attn_mask (Optional[Tensor], default: None ) –

    Optional attention mask of shape (batch, seq_len, seq_len), (batch*n_heads, seq_len, seq_len), or (batch, n_heads, seq_len, seq_len), where True indicates the corresponding query/key product should be masked out.

Returns:
  • Tensor –

    Output embeddings with the same shape as the input x.

Raises:
  • ValueError –

    If tensor shapes are incompatible or position dimensions don't match.
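
A sketch of a forward call on synthetic single-image inputs, with shapes following the parameter descriptions above. Here block is the instance from the construction sketch; the token counts, coordinate dtypes, and use of random data are assumptions for illustration:

```python
import torch

embed_dim, num_levels = 256, 2
level_spatial_shapes = torch.tensor([[32, 32], [16, 16]])  # (num_levels, 2)

# Suppose 100 tokens survive sparsification at level 0 and 40 at level 1.
counts = torch.tensor([100, 40])
seq_len = int(counts.sum())                                # stacked_sequence_length

x = torch.randn(seq_len, embed_dim)
level_indices = torch.repeat_interleave(torch.arange(num_levels), counts)
# Random integer grid coordinates in each level's own (unnormalized) space.
spatial_positions = torch.cat([
    torch.stack([torch.randint(0, h, (n,)),
                 torch.randint(0, w, (n,))], dim=-1)
    for (h, w), n in zip(level_spatial_shapes.tolist(), counts.tolist())
]).float()
# One batch element: its tokens span the whole stacked sequence.
batch_offsets = torch.tensor([0, seq_len])

out = block(x, spatial_positions, level_indices,
            level_spatial_shapes, batch_offsets)
assert out.shape == x.shape                                # same shape as input
```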

reset_parameters()

Resets parameters to default initializations.