Bases: Module
Self-attention block for multi-level feature maps with rotary position encoding.
This module applies self-attention across tokens from multiple resolution levels,
using Rotary Position Encodings (RoPE) to encode the spatial positions of tokens.
It is meant as a plug-and-play standard multi-head attention Transformer block,
encapsulating the input projection, RoPE encoding, attention operation, output
projection with optional dropout, and layer normalization in either a pre-norm
or post-norm configuration. A construction sketch follows the parameter listing below.
Parameters:

- embed_dim (int) – Dimensionality of the input and output embeddings.
- n_heads (int) – Number of attention heads.
- position_dim (int, default: 2) – Dimensionality of the spatial positions (2 for 2D positions).
- dropout (float, default: 0.0) – Dropout probability for the attention weights and the output projection.
- bias (bool, default: False) – Whether to use bias in the linear layers.
- norm_first (bool, default: True) – Whether to apply layer normalization before attention (pre-norm) rather than after (post-norm).
- rope_spatial_base_theta (float, default: 100.0) – Base theta value for the RoPE spatial dimensions.
- rope_level_base_theta (float, default: 10.0) – Base theta value for the RoPE level dimension.
- rope_share_heads (bool, default: False) – Whether to share RoPE frequencies across attention heads.
- rope_freq_group_pattern (str, default: 'single') – Pattern used to group RoPE frequencies. Options: "single", "partition", "closure".
- rope_enforce_freq_groups_equal (bool, default: True) – Passed to the RoPE encoding module. Determines whether it raises an error if the selected rope_freq_group_pattern has a number of frequency groups that does not evenly divide the RoPE encoding dimensions.
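The snippet below is a minimal construction sketch. The import path and class name are placeholders (they are not shown on this page), so substitute the actual ones from this package; the keyword arguments mirror the parameters listed above.

```python
# Hypothetical import path and class name; replace with the actual ones from this package.
from my_package.layers import MultilevelSelfAttentionWithRoPE

attn_block = MultilevelSelfAttentionWithRoPE(
    embed_dim=256,
    n_heads=8,
    position_dim=2,                      # 2D (height, width) positions
    dropout=0.1,
    bias=False,
    norm_first=True,                     # pre-norm configuration
    rope_spatial_base_theta=100.0,
    rope_level_base_theta=10.0,
    rope_share_heads=False,
    rope_freq_group_pattern="single",
    rope_enforce_freq_groups_equal=True,
)
attn_block.reset_parameters()            # optional: re-apply the default initialization
```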
forward(x, spatial_positions, level_indices, level_spatial_shapes, batch_offsets, attn_mask=None)
Forward pass of multi-level self-attention with RoPE. A usage sketch follows the listings below.
Parameters:

- x (Tensor) – Input embeddings of shape (stacked_sequence_length, embed_dim). Contains embeddings from all batches concatenated together.
- spatial_positions (Tensor) – Spatial positions of each token, shape (stacked_sequence_length, position_dim). These are expected to be in the original coordinate space of their respective levels, NOT normalized to the [0, 1] range.
- level_indices (Tensor) – Level index for each token, shape (stacked_sequence_length,).
- level_spatial_shapes (Tensor) – Spatial dimensions of each level, shape (num_levels, position_dim). Contains the height and width of the feature map at each resolution level.
- batch_offsets (Tensor) – Tensor of shape (batch_size + 1,) indicating where each batch starts in the stacked sequence.
- attn_mask (Optional[Tensor], default: None) – Optional attention mask of shape (batch, seq_len, seq_len), (batch*n_heads, seq_len, seq_len), or (batch, n_heads, seq_len, seq_len), where True indicates that the corresponding query/key product should be masked out.
Returns:

- Tensor – Output embeddings with the same shape as the input x.
Raises:

- ValueError – If tensor shapes are incompatible or position dimensions don't match.
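A minimal usage sketch, assuming attn_block is the module constructed in the earlier sketch (hypothetical name) and using two feature-map levels of size 4x4 and 2x2 for a batch of two samples; the tensor shapes follow the parameter descriptions above.

```python
import torch

embed_dim, position_dim = 256, 2
level_spatial_shapes = torch.tensor([[4, 4], [2, 2]])           # (num_levels, position_dim)

# Per-sample token metadata: unnormalized (y, x) positions and the level index of each token.
per_sample_positions = torch.cat([
    torch.cartesian_prod(torch.arange(h), torch.arange(w)).float()
    for h, w in level_spatial_shapes.tolist()
])                                                              # (16 + 4, position_dim)
per_sample_levels = torch.cat([
    torch.full((h * w,), level)
    for level, (h, w) in enumerate(level_spatial_shapes.tolist())
])                                                              # (16 + 4,)

# Stack two samples into one concatenated sequence of 40 tokens.
batch_size, tokens_per_sample = 2, per_sample_positions.shape[0]
spatial_positions = per_sample_positions.repeat(batch_size, 1)  # (40, position_dim)
level_indices = per_sample_levels.repeat(batch_size)            # (40,)
batch_offsets = torch.tensor([0, tokens_per_sample, 2 * tokens_per_sample])  # (batch_size + 1,)
x = torch.randn(batch_size * tokens_per_sample, embed_dim)      # stacked input embeddings

out = attn_block(x, spatial_positions, level_indices, level_spatial_shapes, batch_offsets)
assert out.shape == x.shape                                     # same shape as the input
```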
reset_parameters()
Resets parameters to default initializations.