RoPEEncodingND

Overview

RoPEEncodingND is the main user-facing API. It is intended to be used in Transformer-style attention as an additional step between the initial Q/K/V in-projection and the query-key product.

Basic usage example for self-attention:

import torch.nn.functional as F
from torch import Tensor, nn

from nd_rotary_encodings import RoPEEncodingND


class RoPEAttention_nd(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        rope_theta: float = 10.0,
        use_checkpointing: bool = False,
        forward_only: bool = False,
        inplace: bool = False,
    ):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)

        self.pos_encoding = RoPEEncodingND(
            position_dim=2,
            embed_dim=embed_dim,
            n_heads=num_heads,
            rope_base_theta=rope_theta,
            use_checkpointing=use_checkpointing,
            forward_only=forward_only,
            inplace=inplace,
        )

        self.dropout_p = dropout
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x: Tensor, x_positions: Tensor):
        batch_size, seq_len, embed_dim = x.shape
        head_dim = embed_dim // self.num_heads
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        q, k = self.pos_encoding(q, x_positions, k)

        q = q.reshape(batch_size, seq_len, self.num_heads, head_dim).permute(0, 2, 1, 3)
        k = k.reshape(batch_size, seq_len, self.num_heads, head_dim).permute(0, 2, 1, 3)
        v = v.reshape(batch_size, seq_len, self.num_heads, head_dim).permute(0, 2, 1, 3)

        x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout_p)

        x = x.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
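
A quick smoke test of the module above (a hedged sketch; the shapes and the position_grid call are illustrative, and assume the 2D positions are raw pixel indices):

import torch

attn = RoPEAttention_nd(embed_dim=64, num_heads=4)

# A batch of 2 feature maps of 8x8 = 64 patch embeddings, flattened to sequences.
x = torch.randn(2, 64, 64)

# Un-normalized integer (row, col) positions of shape [batch, seq_len, position_dim].
# position_grid builds the index grid from the unflattened embeddings shape.
positions = RoPEEncodingND.position_grid((2, 8, 8, 64))  # -> [8, 8, 2]
positions = positions.reshape(1, 64, 2).expand(2, -1, -1)

out = attn(x, positions)  # same shape as x: [2, 64, 64]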

RoPEEncodingND

Bases: Module

N-dimensional Rotary Position Embedding (RoPE) module.

Implements rotary position embeddings for arbitrary dimensional positional inputs. This module applies RoPE to queries and keys in attention mechanisms, enabling position-aware attention across N spatial dimensions.

Parameters:
  • position_dim (int) –

    Number of position dimensions (e.g., 2 for 2D, 3 for 3D).

  • embed_dim (int) –

    Total embedding dimension, must be divisible by n_heads.

  • n_heads (int) –

    Number of attention heads.

  • share_heads (bool, default: False ) –

    If True, only one set of frequencies is created per frequency group and shared among all attention heads, similar to traditional 1D RoPE. Defaults to False.

  • freq_group_pattern (Optional[Tensor], default: None ) –

    Boolean tensor of shape [n_freq_groups, position_dim] defining frequency group inclusion. The (head_dim/2) elements of the RoPE encoding vector will be split among the frequency groups. If entry (i, j) is True, then frequency group i includes position dimension j. If None, freq_group_pattern will default to an all-True tensor of shape [1, position_dim]; i.e., one frequency group with all position dimensions.

  • enforce_freq_groups_equal (bool, default: True ) –

    If True, a ValueError is raised if the (head_dim/2) available elements of the RoPE vector cannot be evenly split between the frequency groups. If False, then trailing frequency groups may have fewer RoPE encodings assigned to them.

  • rope_base_theta (Union[Tensor, float], default: 10.0 ) –

    Base value(s) for frequency scaling. Can be a single float (applied to all dimensions and frequency groups) or a 2D tensor of shape [n_freq_groups, position_dim], with either component dimension allowed to be 1 for broadcasting. Entries corresponding to non-included position dimensions in a frequency group will be ignored. Larger values of theta result in lower-frequency rotations, and may be more suitable for dimensions of greater spatial scale. Default: 10.0

  • use_checkpointing (bool, default: False ) –

    If True, an end-to-end custom autograd Function is used that recomputes the RoPE rotation tensor(s) from the positions and frequencies during the backward pass rather than storing it. Since this tensor scales with the sequence length and embedding dimension, recomputing it can potentially lead to large memory savings when the sequence length is large. Default: False

  • forward_only (bool, default: False ) –

    If True, additional optimizations involving in-place updates of intermediate tensors are applied. This mode is safe for forward passes but is incompatible with autograd. This option is also incompatible with use_checkpointing. Default: False.

  • inplace (bool, default: False ) –

    If True, the embeddings are rotated in-place (i.e., the tensor values are overwritten). Requires forward_only. Default: False.

  • dtype (torch.dtype, default: torch.float ) –

    Data type for the internal frequency parameters. Default: torch.float

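
For the less common constructor arguments, here is a hedged sketch (values are illustrative, not recommendations) of a 3D encoding in which the x and y dimensions share one frequency group while z gets its own group with a larger base theta:

import torch

from nd_rotary_encodings import RoPEEncodingND

# Two frequency groups over 3 position dimensions: group 0 covers (x, y),
# group 1 covers z. The head_dim/2 encodings are split evenly between groups.
freq_group_pattern = torch.tensor(
    [
        [True, True, False],  # group 0: x and y
        [False, False, True],  # group 1: z
    ]
)

# Per-group base theta of shape [n_freq_groups, 1], broadcast over position
# dimensions: a larger theta (lower-frequency rotations) for the z group.
rope_theta = torch.tensor([[10.0], [100.0]])

pos_encoding = RoPEEncodingND(
    position_dim=3,
    embed_dim=256,  # head_dim = 256 / 8 = 32, so 16 encodings: 8 per group
    n_heads=8,
    freq_group_pattern=freq_group_pattern,
    rope_base_theta=rope_theta,
)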
Source code in nd_rotary_encodings/position_encoding_layer/rope_encoding_layer.py
class RoPEEncodingND(nn.Module):
    """N-dimensional Rotary Position Embedding (RoPE) module.

    Implements rotary position embeddings for arbitrary dimensional positional inputs.
    This module applies RoPE to queries and keys in attention mechanisms, enabling
    position-aware attention across N spatial dimensions.

    Args:
        position_dim (int): Number of position dimensions (e.g., 2 for 2D, 3 for 3D).
        embed_dim (int): Total embedding dimension, must be divisible by n_heads.
        n_heads (int): Number of attention heads.
        share_heads (bool): If True, only one set of frequencies is created per
            frequency group and shared among all attention heads, similar to
            traditional 1D RoPE. Defaults to False.
        freq_group_pattern (Optional[Tensor]): Boolean tensor of shape
            [n_freq_groups, position_dim] defining frequency group inclusion. The
            (head_dim/2) elements of the RoPE encoding vector will be split among the
            frequency groups. If entry (i, j) is True, then frequency group i
            includes position dimension j. If None, freq_group_pattern will default
            to an all-True tensor of shape [1, position_dim]; i.e., one frequency group
            with all position dimensions.
        enforce_freq_groups_equal (bool): If True, a ValueError is raised if the
            (head_dim/2) available elements of the RoPE vector cannot be evenly
            split between the frequency groups. If False, then trailing frequency
            groups may have fewer RoPE encodings assigned to them.
        rope_base_theta (Union[Tensor, float]): Base value(s) for frequency scaling.
            Can be a single float (applied to all dimensions and frequency groups)
            or a 2D tensor of shape [n_freq_groups, position_dim], with either
            component dimension allowed to be 1 for broadcasting. Entries corresponding
            to non-included position dimensions in a frequency group will be ignored.
            Larger values of theta result in lower-frequency rotations, and may be more
            suitable for dimensions of greater spatial scale. Default: 10.0
        use_checkpointing (bool): If True, an end-to-end custom autograd Function is used
            that recomputes the RoPE rotation tensor(s) from the positions and
            frequencies during the backward pass rather than storing it. Since this
            tensor scales with the sequence length and embedding dimension, recomputing
            it can potentially lead to large memory savings when the sequence length is
            large. Default: False
        forward_only (bool): If True, additional optimizations involving in-place
            updates of intermediate tensors are applied. This mode is safe for
            forward passes but is incompatible with autograd. This option is also
            incompatible with use_checkpointing. Default: False.
        inplace (bool): If True, the embeddings are rotated in-place (i.e., the tensor
            values are overwritten). Requires forward_only. Default: False.
        dtype (torch.dtype): Data type for the internal frequency parameters.
            Default: torch.float
    """

    if TYPE_CHECKING:
        freq_group_pattern: Tensor
        freq_pos_indices: Tensor
        freq_group_indices: Tensor
        freq_head_indices: Tensor
        freq_enc_indices: Tensor
        encoding_ranges: Tensor

    def __init__(
        self,
        position_dim: int,
        embed_dim: int,
        n_heads: int,
        share_heads: bool = False,
        freq_group_pattern: Optional[Tensor] = None,
        enforce_freq_groups_equal: bool = True,
        rope_base_theta: Union[Tensor, float, list[list[float]]] = 10.0,
        use_checkpointing: bool = False,
        forward_only: bool = False,
        inplace: bool = False,
        dtype=torch.float,
    ):
        """Initialize the module."""
        super().__init__()

        if inplace and not forward_only:
            raise ValueError("`inplace=True` requires `forward_only=True`")
        if forward_only and use_checkpointing:
            raise ValueError("`use_checkpointing` is incompatible with `forward_only`")

        self.embed_dim = embed_dim
        if embed_dim % n_heads != 0:
            raise ValueError(
                "Expected embed_dim to be divisible by n_heads, got "
                f"{embed_dim} and {n_heads}"
            )
        self.head_dim = embed_dim // n_heads
        if self.head_dim % 2 != 0:
            raise ValueError(
                f"Expected head_dim to be divisible by 2, got {self.head_dim}"
            )
        self.position_dim = position_dim
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.share_heads = share_heads
        self.use_checkpointing = use_checkpointing
        self.forward_only = forward_only
        self.inplace = inplace

        if freq_group_pattern is None:
            # default frequency group pattern: one group with all position dimensions
            freq_group_pattern = torch.ones(1, position_dim, dtype=torch.bool)
        freq_group_pattern = torch.as_tensor(freq_group_pattern, dtype=torch.bool)

        self.enforce_freq_groups_equal = enforce_freq_groups_equal

        self.validate_freq_group_pattern(freq_group_pattern)
        self.register_buffer("freq_group_pattern", freq_group_pattern)
        self.n_freq_groups = freq_group_pattern.size(0)

        self._base_theta = torch.as_tensor(rope_base_theta, dtype=dtype)
        self.dtype = dtype
        self._init_freq_param()

    def _init_freq_param(self):
        """Initialize the frequency parameters for the RoPE module.

        Creates and stores the frequency parameters as trainable parameters and
        precomputes the indices used to construct the full sparse RoPE frequency
        tensor.
        """
        effective_n_heads = self.n_heads if not self.share_heads else 1

        freqs, encoding_ranges = init_nd_freqs(
            self.position_dim,
            self.head_dim,
            effective_n_heads,
            self.freq_group_pattern,
            self.enforce_freq_groups_equal,
            self._base_theta,
            dtype=self.dtype,
        )
        self.validate_grouped_freqs(freqs, encoding_ranges)

        self.freqs = nn.ParameterList(freqs)
        self.register_buffer("encoding_ranges", encoding_ranges)

        # Precompute indices for grouped_rope_freqs_tensor
        indices_list = []

        for g, enc_range in enumerate(encoding_ranges):
            range_start, range_end = (int(r) for r in enc_range)
            range_size = range_end - range_start
            pos_dims = torch.nonzero(self.freq_group_pattern[g], as_tuple=True)[0]

            # if range_size == 0 or pos_dims.numel() == 0: # empty frequency group
            #     continue

            # Create indexing tensors for this frequency group
            # Order matches output tensor shape: [position_dim, n_freq_groups, n_heads, head_dim//2]
            pos_idx = pos_dims.view(-1, 1, 1).expand(-1, effective_n_heads, range_size)
            g_idx = torch.full(
                (pos_dims.size(0), effective_n_heads, range_size),
                g,
                dtype=torch.long,
                device=pos_dims.device,
            )
            head_idx = (
                torch.arange(effective_n_heads, device=pos_dims.device)
                .view(1, -1, 1)
                .expand(pos_dims.size(0), -1, range_size)
            )
            dim_idx = (
                torch.arange(range_start, range_end, device=pos_dims.device)
                .view(1, 1, -1)
                .expand(pos_dims.size(0), effective_n_heads, -1)
            )

            # Stack with dimension order matching output tensor
            indices = torch.stack(
                [
                    pos_idx.flatten(),
                    g_idx.flatten(),
                    head_idx.flatten(),
                    dim_idx.flatten(),
                ],
                dim=0,
            )
            indices_list.append(indices)

        # Concatenate all indices
        indices = torch.cat(indices_list, dim=1)

        # store indices for construction of freq tensor in forward pass
        pos_indices, group_indices, head_indices, enc_indices = indices.unbind(0)
        self.register_buffer("freq_pos_indices", pos_indices)
        self.register_buffer("freq_group_indices", group_indices)
        self.register_buffer("freq_head_indices", head_indices)
        self.register_buffer("freq_enc_indices", enc_indices)

    def validate_freq_group_pattern(self, freq_group_pattern: Tensor):
        if freq_group_pattern.ndim != 2:
            raise ValueError(
                "Expected 2D tensor for freq_group_pattern, got shape "
                f"{freq_group_pattern.size()}"
            )
        if freq_group_pattern.size(1) != self.position_dim:
            raise ValueError(
                "Expected second dimension of freq_group_pattern to have size equal to "
                f"position_dim, got freq_group_pattern shape {freq_group_pattern.size()} "
                f"and position_dim={self.position_dim}"
            )
        n_freq_groups = freq_group_pattern.size(0)
        half_head_dim = self.head_dim // 2
        remainder = half_head_dim % n_freq_groups

        if remainder > 0 and self.enforce_freq_groups_equal:
            raise ValueError(
                f"RoPE encodings ({half_head_dim}) not evenly divisible by frequency "
                f"groups ({n_freq_groups})"
            )

    def validate_grouped_freqs(self, freqs: list[Tensor], encoding_ranges: Tensor):
        # Validate number of frequency groups
        n_freq_groups = len(freqs)
        if self.freq_group_pattern.size(0) != n_freq_groups:
            raise ValueError(
                "Expected the first dimension of freq_group_pattern (shape: "
                f"{self.freq_group_pattern.shape}) to have size equal to the length of the"
                f"freqs list ({len(freqs)})"
            )

        # Validate head_dim is consistent
        half_head_dim_list = [f.size(2) for f in freqs]
        if len(set(half_head_dim_list)) != 1 and self.enforce_freq_groups_equal:
            raise ValueError(
                "Expected tensors in freqs to all have the same number of "
                f"RoPE encodings; got {half_head_dim_list}"
            )

        # Validate n_heads is consistent
        n_heads_list = [f.size(1) for f in freqs]
        n_heads_set = set(n_heads_list)
        if not (
            len(n_heads_set) == 1
            or (len(n_heads_set) == 2 and len(n_heads_set - set((1,))) == 1)
        ):
            raise ValueError(
                "Expected tensors in freqs to have number of attention heads "
                f"all equal and/or 1, got {n_heads_list}"
            )

        # Validate encoding ranges
        if encoding_ranges.size(0) != n_freq_groups:
            raise ValueError(
                "Expected first dim of encoding_ranges to be equal to n_freq_groups "
                f"({n_freq_groups}), got shape {encoding_ranges}"
            )

        if not (
            torch.all(encoding_ranges[:, 0] <= encoding_ranges[:, 1])
            and torch.all(encoding_ranges[:-1, 1] == encoding_ranges[1:, 0])
        ):
            raise ValueError(
                "Expected encoding_ranges to be a 2D tensor of contiguous, "
                f"non-overlapping slices, got {encoding_ranges}"
            )

    def forward(
        self,
        query: Tensor,
        query_pos: Tensor,
        key: Optional[Tensor] = None,
        key_pos: Optional[Tensor] = None,
    ) -> Union[Tensor, tuple[Tensor, Tensor]]:
        """Apply rotary position embeddings to query and optionally key tensors.

        Applies position-dependent rotations to query and key tensors based on
        their associated position information.

        Args:
            query (Tensor): Query tensor of shape [..., embed_dim].
            query_pos (Tensor): Position tensor for query of shape
                [..., position_dim]. The leading dimensions must match those of query.
                It is assumed that the positions are NOT normalized to the standard
                [0, 1] range and are instead the true positions.
            key (Optional[Tensor]): Key tensor of shape [..., embed_dim]. Default: None
            key_pos (Optional[Tensor]): Position tensor for key of shape
                [..., position_dim]. If None and key is provided, query_pos will be
                used. It is assumed that the positions are NOT normalized to the
                standard [0, 1] range and are instead the true positions. Default: None

        Returns:
            Union[Tensor, tuple[Tensor, Tensor]]:
                - If key is None: Rotated query tensor of same shape as input query.
                - If key is provided: Tuple of (rotated query, rotated key) tensors.

        Note:
            - For query/key embeddings with a regular grid structure, a default
                position grid may be obtained from the static method `position_grid`.

        Raises:
            ValueError: If the tensor shapes are incompatible.

        Warns:
            UserWarning: If position coordinates appear to be normalized
                (in [0,1] range).
        """

        self.shape_check(query, query_pos)
        if query_pos.numel() > 0 and query_pos.min() > 0.0 and query_pos.max() <= 1.0:
            warnings.warn(
                "Expected un-normalized (i.e., not inside [0,1]) coordinates "
                "for position but found potentially normalized coordinates. "
                "Did you accidentally pass in normalized coordinates?\n(Your coord "
                f"range: [{query_pos.min().item(), query_pos.max().item()}])",
                UserWarning,
            )
        if key_pos is not None:  # Check key if present
            assert key is not None
            self.shape_check(key, key_pos)

        # Construct full frequency tensor from component parameters
        freq_tensor = self.grouped_rope_freqs_tensor(self.freqs)

        # unstack heads
        query = query.reshape(query.shape[:-1] + (self.n_heads, self.head_dim))
        if key is not None:
            key = key.reshape(key.shape[:-1] + (self.n_heads, self.head_dim))

        key_rotated = None
        # select proper path
        if self.forward_only or self.use_checkpointing:
            # use end-to-end function
            if key is not None and key_pos is None:
                # query and key share same positions
                query_rotated, key_rotated = self.apply_rope_endtoend(
                    query, query_pos, freq_tensor, self.forward_only, self.inplace, key
                )
            else:
                query_rotated = self.apply_rope_endtoend(
                    query, query_pos, freq_tensor, self.forward_only, self.inplace
                )
                if key is not None:
                    assert key_pos is not None
                    key_rotated = self.apply_rope_endtoend(
                        key, key_pos, freq_tensor, self.forward_only, self.inplace
                    )
        else:
            query_rot_vec = self.calculate_rope(query_pos, freq_tensor)
            query_rotated = self.rotate_embeddings(query, query_rot_vec)
            if key is not None:
                if key_pos is None:
                    key_rotated = self.rotate_embeddings(key, query_rot_vec)
                else:
                    key_rot_vec = self.calculate_rope(key_pos, freq_tensor)
                    key_rotated = self.rotate_embeddings(key, key_rot_vec)

        assert isinstance(query_rotated, Tensor)
        # stack heads back
        query_rotated = query_rotated.view(query_rotated.shape[:-2] + (self.embed_dim,))

        if key is None:
            return query_rotated

        assert isinstance(key_rotated, Tensor)
        key_rotated = key_rotated.view(key_rotated.shape[:-2] + (self.embed_dim,))

        return query_rotated, key_rotated

    @staticmethod
    def position_grid(
        embeddings_shape: Union[Sequence[int], Tensor],
        start_dim: int = 1,
        end_dim: int = -1,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Tensor:
        """Generates a regularly-spaced grid of positions based on the input shape.

        This function may be used to generate a tensor of positions corresponding to
        the tensor indices of each element in the embeddings tensor. This is
        potentially useful for regularly-spaced queries and keys, such as embeddings
        corresponding to text tokens or image pixels. The exact form of the position
        grid tensor is torch.stack(torch.meshgrid(
            *[torch.arange(size) for size in embeddings_shape[start_dim:end_dim]],
            indexing="ij"
        ), dim=-1)

        Args:
            embeddings_shape (Sequence[int] | Tensor): The full shape of the embeddings tensor.
            start_dim (int, optional): Start index of the position dimensions in
                embeddings_shape, inclusive. Defaults to 1 (i.e., one batch dim).
            end_dim (int, optional): End index of the position dimensions in
                embeddings_shape, exclusive. Defaults to -1 (i.e., one feature dim).
            device (Optional[Union[str, torch.device]]): The device on which to create the
                tensor. Defaults to None (i.e., default device).
            dtype (Optional[torch.dtype]): The dtype for the created tensor. Defaults
                to None (i.e., default dtype).

        Returns:
            Tensor: Created position grid tensor, of shape
                [*embeddings_shape[start_dim:end_dim],
                 len(embeddings_shape[start_dim:end_dim])]
        """
        grid = torch.stack(
            torch.meshgrid(
                *[
                    torch.arange(int(size), device=device, dtype=dtype)
                    for size in embeddings_shape[start_dim:end_dim]
                ],
                indexing="ij",
            ),
            dim=-1,
        )
        return grid

    def grouped_rope_freqs_tensor(
        self,
        grouped_rope_freqs: Union[list[Tensor], nn.ParameterList],
    ) -> Tensor:
        """Use frequency group information to build the full RoPE frequency tensor that
        is multiplied by the positions to produce RoPE encodings.

        This function takes the per-group RoPE frequencies to construct the RoPE
        frequency tensor. The RoPE frequency tensor has shape
        [position_dim, n_freq_groups, n_heads, head_dim/2], and is zero at positions
        where a position dimension is not included in a frequency group. The
        frequencies are stored separately per frequency group in tensors of shape
        [position_dim_g, n_heads, group_encoding_dim] because each frequency group may
        have a different number of active position dimensions and/or assigned encoding
        dimensions.

        Args:
            grouped_rope_freqs (list[Tensor]): List of per-group frequency tensors, as
                generated by init_nd_freqs, each of shape
                [
                    position_dim_g,
                    n_heads,
                    {head_dim//(2 * n_freq_groups), head_dim//(2 * n_freq_groups) + 1}
                ],
                where position_dim_g is the number of position dimensions included in
                frequency group g.

        Returns:
            Tensor: RoPE frequency tensor of shape
                [position_dim, n_freq_groups, n_heads, head_dim/2] or
                [position_dim, n_freq_groups,       1, head_dim/2], with nonzero
                elements corresponding to position dimensions included in each
                frequency group. It may be passed to `calculate_rope` with the
                positions tensor to compute RoPE encodings.
        """
        if isinstance(grouped_rope_freqs, Tensor):
            grouped_rope_freqs = [grouped_rope_freqs]

        # Create output tensor
        rope_freqs = grouped_rope_freqs[0].new_zeros(
            self.position_dim,
            self.n_freq_groups,
            self.n_heads if not self.share_heads else 1,
            self.head_dim // 2,
        )

        values = torch.cat([fg.flatten() for fg in grouped_rope_freqs])

        rope_freqs.index_put_(
            (
                self.freq_pos_indices,
                self.freq_group_indices,
                self.freq_head_indices,
                self.freq_enc_indices,
            ),
            values,
        )

        return rope_freqs

    @staticmethod
    def calculate_rope(positions: Tensor, rope_freqs: Tensor) -> Tensor:
        """Creates rotation vectors from position coordinates and RoPE frequencies.

        Transforms positional information into rotation vectors for RoPE.

        Args:
            positions (Tensor): Position tensor of shape [..., position_dim].
            rope_freqs (Tensor): Frequency tensor for rotary encodings of shape
                [position_dim, n_freq_groups, n_heads, head_dim/2].

        Returns:
            Tensor: Real-valued positional encodings of shape
                [..., n_heads, head_dim/2].
        """
        return calculate_rope(positions.to(rope_freqs), rope_freqs)

    @staticmethod
    def rotate_embeddings(
        query_or_key: Tensor,
        rope_encoding: Tensor,
        forward_only: bool = False,
        inplace: bool = False,
    ) -> Tensor:
        """Applies rotary embeddings to query or key tensor using complex
        multiplication.

        Rotates the query or key tensor using the rotation vectors via complex
        multiplication.

        Args:
            query_or_key (Tensor): Query or key tensor of shape
                [..., n_heads, head_dim].
            rope_encoding (Tensor): Real-valued RoPE encoding tensor of shape
                [..., n_heads, head_dim/2].

        Returns:
            Tensor: Rotated query or key tensor of same shape as input query_or_key.
        """
        # Unsqueeze rope_encoding if needed
        dim_diff = query_or_key.ndim - rope_encoding.ndim
        if dim_diff > 0:
            rope_encoding = rope_encoding.view((1,) * dim_diff + rope_encoding.shape)
        if forward_only:
            return rotate_embeddings_forward_only(
                query_or_key, rope_encoding, inplace=inplace
            )
        return rotate_embeddings(query_or_key, rope_encoding)

    def apply_rope_endtoend(
        self,
        embeddings: Tensor,
        positions: Tensor,
        rope_freqs: Tensor,
        forward_only: bool = False,
        inplace: bool = False,
        key_embeddings: Optional[Tensor] = None,
    ) -> Union[Tensor, tuple[Tensor, Tensor]]:
        """Memory-optimized calculation and application of RoPE from components.

        Uses an end-to-end function to calculate RoPE rotation vectors from
        position coordinates and RoPE frequencies, then apply them to query or key
        embeddings.
        Under the hood, this uses a custom autograd Function and explicit backward
        calculation to save memory, at the cost of recalculating the RoPE rotation
        vectors during the backward pass.

        Args:
            embeddings (Tensor): Embeddings tensor of shape
                [..., n_heads, head_dim].
            positions (Tensor): Position tensor of shape [..., position_dim].
            rope_freqs (Tensor): Frequency tensor for rotary encodings of shape
                [position_dim, n_freq_groups, n_heads, head_dim/2].
            forward_only (bool): If True, the forward-only (autograd-incompatible)
                code path is used. Default: False.
            inplace (bool): If True, the embeddings are rotated in-place. Requires
                forward_only. Default: False.
            key_embeddings (Optional[Tensor]): For self-attention, key embeddings
                that share positions with `embeddings`, of shape
                [..., n_heads, head_dim]. Default: None.

        Returns:
            Union[Tensor, tuple[Tensor, Tensor]]: Rotated embeddings of the same
                shape as the input, or a tuple of (rotated query, rotated key)
                tensors if key_embeddings is provided.
        """
        if forward_only:
            return apply_rope_forward_only(
                embeddings,
                positions.to(rope_freqs),
                rope_freqs,
                inplace=inplace,
                self_attn_key_embeddings=key_embeddings,
            )
        return apply_rope_checkpointed(
            embeddings,
            positions.to(rope_freqs),
            rope_freqs,
            key_embeddings=key_embeddings,
        )

    def shape_check(self, query_or_key: Tensor, query_or_key_pos: Tensor):
        """Validates the shapes of query/key and their position tensors.

        Args:
            query_or_key (Tensor): Query or key tensor of shape [..., embed_dim].
            query_or_key_pos (Tensor): Position tensor of shape [..., position_dim].
                Must be broadcastable to the shape of query_or_key.

        Raises:
            ValueError: If tensor shapes are incompatible.
        """
        if not can_broadcast_shapes(
            query_or_key.shape[:-1], query_or_key_pos.shape[:-1]
        ):
            raise ValueError(
                "Expected leading dims of query_or_key_pos to be broadcastable to "
                "leading dims of query_or_key, but got shapes "
                f"{query_or_key_pos.shape} and {query_or_key.shape}, respectively."
            )
        if query_or_key.shape[-1] != self.embed_dim:
            raise ValueError(
                "Expected query_or_key to have last dim equal to embed_dim "
                f"(={self.embed_dim}), got {query_or_key.shape[-1]}"
            )
        if query_or_key_pos.shape[-1] != self.position_dim:
            raise ValueError(
                "Expected query_or_key_pos to have last dim equal to pos_dim "
                f"(={self.position_dim}), got {query_or_key_pos.shape[-1]}"
            )

    def reset_parameters(self):
        """Resets frequency parameters"""
        freqs, _ = init_nd_freqs(
            self.position_dim,
            self.head_dim,
            self.n_heads if not self.share_heads else 1,
            self.freq_group_pattern,
            self.enforce_freq_groups_equal,
            self._base_theta,
            dtype=self.dtype,
            device=self.freqs[0].device,
        )
        with torch.no_grad():
            for param, init in zip(self.freqs, freqs):
                param.copy_(init)

forward(query, query_pos, key=None, key_pos=None)

Apply rotary position embeddings to query and optionally key tensors.

Applies position-dependent rotations to query and key tensors based on their associated position information.

Parameters:
  • query (Tensor) –

    Query tensor of shape [..., embed_dim].

  • query_pos (Tensor) –

    Position tensor for query of shape [..., position_dim]. The leading dimensions must match those of query. It is assumed that the positions are NOT normalized to the standard [0, 1] range and are instead the true positions.

  • key (Optional[Tensor], default: None ) –

    Key tensor of shape [..., embed_dim]. Default: None

  • key_pos (Optional[Tensor], default: None ) –

    Position tensor for key of shape [..., position_dim]. If None and key is provided, query_pos will be used. It is assumed that the positions are NOT normalized to the standard [0, 1] range and are instead the true positions. Default: None

Returns:
  • Union[Tensor, tuple[Tensor, Tensor]] –

    If key is None: rotated query tensor of same shape as input query.
    If key is provided: tuple of (rotated query, rotated key) tensors.

Note
  • For query/key embeddings with a regular grid structure, a default position grid may be obtained from the static method position_grid.
Raises:
  • ValueError

    If the tensor shapes are incompatible.

Warns:
  • UserWarning

    If position coordinates appear to be normalized (in [0,1] range).

Source code in nd_rotary_encodings/position_encoding_layer/rope_encoding_layer.py
def forward(
    self,
    query: Tensor,
    query_pos: Tensor,
    key: Optional[Tensor] = None,
    key_pos: Optional[Tensor] = None,
) -> Union[Tensor, tuple[Tensor, Tensor]]:
    """Apply rotary position embeddings to query and optionally key tensors.

    Applies position-dependent rotations to query and key tensors based on
    their associated position information.

    Args:
        query (Tensor): Query tensor of shape [..., embed_dim].
        query_pos (Tensor): Position tensor for query of shape
            [..., position_dim]. The leading dimensions must match those of query.
            It is assumed that the positions are NOT normalized to the standard
            [0, 1] range and are instead the true positions.
        key (Optional[Tensor]): Key tensor of shape [..., embed_dim]. Default: None
        key_pos (Optional[Tensor]): Position tensor for key of shape
            [..., position_dim]. If None and key is provided, query_pos will be
            used. It is assumed that the positions are NOT normalized to the
            standard [0, 1] range and are instead the true positions. Default: None

    Returns:
        Union[Tensor, tuple[Tensor, Tensor]]:
            - If key is None: Rotated query tensor of same shape as input query.
            - If key is provided: Tuple of (rotated query, rotated key) tensors.

    Note:
        - For query/key embeddings with a regular grid structure, a default
            position grid may be obtained from the static method `position_grid`.

    Raises:
        ValueError: If the tensor shapes are incompatible.

    Warns:
        UserWarning: If position coordinates appear to be normalized
            (in [0,1] range).
    """

    self.shape_check(query, query_pos)
    if query_pos.numel() > 0 and query_pos.min() > 0.0 and query_pos.max() <= 1.0:
        warnings.warn(
            "Expected un-normalized (i.e., not inside [0,1]) coordinates "
            "for position but found potentially normalized coordinates. "
            "Did you accidentally pass in normalized coordinates?\n(Your coord "
            f"range: [{query_pos.min().item(), query_pos.max().item()}])",
            UserWarning,
        )
    if key_pos is not None:  # Check key if present
        assert key is not None
        self.shape_check(key, key_pos)

    # Construct full frequency tensor from component parameters
    freq_tensor = self.grouped_rope_freqs_tensor(self.freqs)

    # unstack heads
    query = query.reshape(query.shape[:-1] + (self.n_heads, self.head_dim))
    if key is not None:
        key = key.reshape(key.shape[:-1] + (self.n_heads, self.head_dim))

    key_rotated = None
    # select proper path
    if self.forward_only or self.use_checkpointing:
        # use end-to-end function
        if key is not None and key_pos is None:
            # query and key share same positions
            query_rotated, key_rotated = self.apply_rope_endtoend(
                query, query_pos, freq_tensor, self.forward_only, self.inplace, key
            )
        else:
            query_rotated = self.apply_rope_endtoend(
                query, query_pos, freq_tensor, self.forward_only, self.inplace
            )
            if key is not None:
                assert key_pos is not None
                key_rotated = self.apply_rope_endtoend(
                    key, key_pos, freq_tensor, self.forward_only, self.inplace
                )
    else:
        query_rot_vec = self.calculate_rope(query_pos, freq_tensor)
        query_rotated = self.rotate_embeddings(query, query_rot_vec)
        if key is not None:
            if key_pos is None:
                key_rotated = self.rotate_embeddings(key, query_rot_vec)
            else:
                key_rot_vec = self.calculate_rope(key_pos, freq_tensor)
                key_rotated = self.rotate_embeddings(key, key_rot_vec)

    assert isinstance(query_rotated, Tensor)
    # stack heads back
    query_rotated = query_rotated.view(query_rotated.shape[:-2] + (self.embed_dim,))

    if key is None:
        return query_rotated

    assert isinstance(key_rotated, Tensor)
    key_rotated = key_rotated.view(key_rotated.shape[:-2] + (self.embed_dim,))

    return query_rotated, key_rotated

position_grid(embeddings_shape, start_dim=1, end_dim=-1, device=None, dtype=None) staticmethod

Generates a regularly-spaced grid of positions based on the input shape.

This function may be used to generate a tensor of positions corresponding to the tensor indices of each element in the embeddings tensor. This is potentially useful for regularly-spaced queries and keys, such as embeddings corresponding to text tokens or image pixels. The exact form of the position grid tensor is torch.stack(torch.meshgrid(*[torch.arange(size) for size in embeddings_shape[start_dim:end_dim]], indexing="ij"), dim=-1).

Parameters:
  • embeddings_shape (Sequence[int] | Tensor) –

    The full shape of the embeddings tensor.

  • start_dim (int, default: 1 ) –

    Start index of the position dimensions in embeddings_shape, inclusive. Defaults to 1 (i.e., one batch dim).

  • end_dim (int, default: -1 ) –

    End index of the position dimensions in embeddings_shape, exclusive. Defaults to -1 (i.e., one feature dim).

  • device (Optional[Union[str, device]], default: None ) –

    The device on which to create the tensor. Defaults to None (i.e., default device).

  • dtype (Optional[dtype], default: None ) –

    The dtype for the created tensor. Defaults to None (i.e., default dtype).

Returns:
  • Tensor –

    Created position grid tensor, of shape [*embeddings_shape[start_dim:end_dim], len(embeddings_shape[start_dim:end_dim])]

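
For example, for a batch of 2D feature maps of shape [batch, height, width, channels], the defaults start_dim=1 and end_dim=-1 pick out the two spatial dimensions (a small illustrative sketch):

import torch

from nd_rotary_encodings import RoPEEncodingND

# Embeddings shape [batch=2, height=4, width=6, channels=64]: the grid spans
# the (height, width) dims selected by embeddings_shape[1:-1].
grid = RoPEEncodingND.position_grid((2, 4, 6, 64))

print(grid.shape)  # torch.Size([4, 6, 2])
print(grid[1, 3])  # tensor([1, 3]) -- each entry holds its own (row, col) index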
Source code in nd_rotary_encodings/position_encoding_layer/rope_encoding_layer.py
@staticmethod
def position_grid(
    embeddings_shape: Union[Sequence[int], Tensor],
    start_dim: int = 1,
    end_dim: int = -1,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Generates a regularly-spaced grid of positions based on the input shape.

    This function may be used to generate a tensor of positions corresponding to
    the tensor indices of each element in the embeddings tensor. This is
    potentially useful for regularly-spaced queries and keys, such as embeddings
    corresponding to text tokens or image pixels. The exact form of the position
    grid tensor is torch.stack(torch.meshgrid(
        *[torch.arange(size) for size in embeddings_shape[start_dim:end_dim]],
        indexing="ij"
    ), dim=-1)

    Args:
        embeddings_shape (Sequence[int] | Tensor): The full shape of the embeddings tensor.
        start_dim (int, optional): Start index of the position dimensions in
            embeddings_shape, inclusive. Defaults to 1 (i.e., one batch dim).
        end_dim (int, optional): End index of the position dimensions in
            embeddings_shape, exclusive. Defaults to -1 (i.e., one feature dim).
        device (Optional[Union[str, torch.device]]): The device on which to create the
            tensor. Defaults to None (i.e., default device).
        dtype (Optional[torch.dtype]): The dtype for the created tensor. Defaults
            to None (i.e., default dtype).

    Returns:
        Tensor: Created position grid tensor, of shape
            [*embeddings_shape[start_dim:end_dim],
             len(embeddings_shape[start_dim:end_dim])]
    """
    grid = torch.stack(
        torch.meshgrid(
            *[
                torch.arange(int(size), device=device, dtype=dtype)
                for size in embeddings_shape[start_dim:end_dim]
            ],
            indexing="ij",
        ),
        dim=-1,
    )
    return grid

Utilities

prep_multilevel_positions(spatial_positions, batch_indices, level_indices, level_spatial_shapes)

Standardizes positional coordinates across multiple resolution levels.

Converts indices or positions from multiple resolution levels to a standardized coordinate system by rescaling each level to match the finest level's resolution. This enables consistent position encoding across hierarchical feature maps.

Parameters:
  • spatial_positions (Tensor) –

    Indices or positions of shape [..., position_dim], where each row contains the N-D position of each point. If floating point, they're treated as coordinates; if integer, they're treated as indices.

  • batch_indices (Tensor) –

    Integer tensor of shape [...], containing the batch index for each position in spatial_positions.

  • level_indices (Tensor) –

    Integer tensor of shape [...], containing the level index for each position in spatial_positions.

  • level_spatial_shapes (Tensor) –

    Tensor of shape [num_levels, 2] or [batch_size, num_levels, 2] specifying the spatial dimensions (height, width) of each level.

Returns:
  • Tensor –

    Rescaled positions of shape [..., position_dim + 1] with floating point dtype, where the last dimension has the level index concatenated onto the end of the spatial coordinates, and the spatial coordinates are standardized to the finest resolution level.

Raises:
  • ValueError

    If tensors don't have the expected shape, dimensions, or dtypes.

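
A small sketch of the rescaling with two feature levels (values are illustrative; if prep_multilevel_positions is not re-exported at the package root, import it from nd_rotary_encodings.position_encoding_layer.utils instead):

import torch

from nd_rotary_encodings import prep_multilevel_positions

# Two pyramid levels: a coarse 4x4 level (index 0) and a fine 8x8 level (index 1).
level_spatial_shapes = torch.tensor([[4, 4], [8, 8]])

# One point per level, given as integer (row, col) indices.
spatial_positions = torch.tensor([[1, 1], [2, 2]])
batch_indices = torch.tensor([0, 0])
level_indices = torch.tensor([0, 1])

out = prep_multilevel_positions(
    spatial_positions, batch_indices, level_indices, level_spatial_shapes
)
# Integer indices become pixel-center coordinates (+0.5), the coarse level is
# rescaled by 8/4 = 2x to the finest resolution, and the level index is appended:
# tensor([[3.0000, 3.0000, 0.0000],
#         [2.5000, 2.5000, 1.0000]])
print(out)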
Source code in nd_rotary_encodings/position_encoding_layer/utils.py
def prep_multilevel_positions(
    spatial_positions: Tensor,
    batch_indices: Tensor,
    level_indices: Tensor,
    level_spatial_shapes: Tensor,
) -> Tensor:
    """Standardizes positional coordinates across multiple resolution levels.

    Converts indices or positions from multiple resolution levels to a standardized
    coordinate system by rescaling each level to match the finest level's resolution.
    This enables consistent position encoding across hierarchical feature maps.

    Args:
        spatial_positions (Tensor): Indices or positions of shape [..., position_dim],
            where each row contains the N-D position of each point. If floating point,
            they're treated as coordinates; if integer, they're treated as indices.
        batch_indices (Tensor): Integer tensor of shape [...], containing the
            batch index for each position in spatial_positions.
        level_indices (Tensor): Integer tensor of shape [...], containing the
            level index for each position in spatial_positions.
        level_spatial_shapes (Tensor): Tensor of shape [num_levels, 2] or
            [batch_size, num_levels, 2] specifying the spatial dimensions
            (height, width) of each level.

    Returns:
        Tensor: Rescaled positions of shape [..., position_dim + 1] with floating
            point dtype, where the last dimension has the level index concatenated onto
            the end of the spatial coordinates, and the spatial coordinates are
            standardized to the finest resolution level.

    Raises:
        ValueError: If tensors don't have the expected shape, dimensions, or dtypes.
    """
    validate_atleast_nd(spatial_positions, 2, "spatial_positions")
    batch_dims = spatial_positions.ndim - 1
    validate_nd(batch_indices, batch_dims, "batch_indices")
    validate_nd(level_indices, batch_dims, "level_indices")

    if not torch.is_floating_point(spatial_positions):
        # convert from indices to coordinates of pixel centers
        spatial_positions = spatial_positions + 0.5

    # batch, level, pos_dim or level, pos_dim
    assert level_spatial_shapes.ndim in (2, 3)

    # Initialize output tensor
    multilevel_positions = spatial_positions.new_zeros(
        spatial_positions.shape[:-1] + (spatial_positions.size(-1) + 1,)
    )

    # Early exit
    if multilevel_positions.numel() == 0:
        return multilevel_positions

    if level_spatial_shapes.ndim == 2:
        level_spatial_shapes = level_spatial_shapes.unsqueeze(0).expand(
            int(torch.max(batch_indices).item()) + 1, -1, -1
        )

    batch_max_spatial_shape = level_spatial_shapes.max(-2)[0]
    max_spatial_shapes = batch_max_spatial_shape[batch_indices]
    indexed_spatial_shapes = level_spatial_shapes[batch_indices, level_indices]

    # Fill in rescaled positions
    multilevel_positions[..., :-1] = spatial_positions / (
        indexed_spatial_shapes / max_spatial_shapes
    )

    # Fill in level indices
    multilevel_positions[..., -1] = level_indices.to(multilevel_positions)

    return multilevel_positions

get_multilevel_freq_group_pattern(position_dim, pattern_name, device=None)

Get a predefined frequency group pattern for RoPE encodings of multilevel features.

Creates a frequency group pattern tensor for use with RoPEEncodingND based on predefined patterns that determine how spatial and level dimensions are encoded.

Parameters:
  • position_dim (int) –

    Spatial dimension of the features to be encoded (2 for 2D images, etc.). The output tensor will have this many spatial dimensions plus 1 dimension for the feature level.

  • pattern_name (Union[str, FreqGroupPattern]) –

    Pattern to use, either as a string or enum value. Options:

    - "single" or FreqGroupPattern.SINGLE: All dimensions (*spatial, level) in a single frequency group
    - "partition" or FreqGroupPattern.PARTITION: Spatial dimensions and level in separate groups
    - "closure" or FreqGroupPattern.CLOSURE: Three groups: spatial, level, and (*spatial, level)

  • device (device, default: None ) –

    Device for the created tensor. Defaults to None.

Returns:
  • Tensor –

    Boolean tensor encoding the frequency group pattern, of shape [n_freq_groups, position_dim + 1]

Raises:
  • ValueError

    If an unrecognized pattern name is provided.

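
A hedged sketch of plugging a predefined pattern into RoPEEncodingND (if the function is not re-exported at the package root, import it from nd_rotary_encodings.position_encoding_layer.utils instead):

import torch

from nd_rotary_encodings import RoPEEncodingND, get_multilevel_freq_group_pattern

# "partition": the two spatial dims in one frequency group, level in another.
pattern = get_multilevel_freq_group_pattern(position_dim=2, pattern_name="partition")
# tensor([[ True,  True, False],
#         [False, False,  True]])

# position_dim is 3 here because each position carries (row, col, level),
# e.g. as produced by prep_multilevel_positions.
pos_encoding = RoPEEncodingND(
    position_dim=3,
    embed_dim=256,
    n_heads=8,
    freq_group_pattern=pattern,
)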
Source code in nd_rotary_encodings/position_encoding_layer/utils.py
def get_multilevel_freq_group_pattern(
    position_dim: int, pattern_name: Union[str, FreqGroupPattern], device=None
) -> Tensor:
    """Get a predefined frequency group pattern for RoPE encodings of multilevel features.

    Creates a frequency group pattern tensor for use with RoPEEncodingND based on
    predefined patterns that determine how spatial and level dimensions are encoded.

    Args:
        position_dim (int): Spatial dimension of the features to be encoded (2 for 2D
            images, etc.). The output tensor will have this many spatial dimensions
            plus 1 dimension for the feature level
        pattern_name (Union[str, FreqGroupPattern]): Pattern to use, either as a string
            or enum value. Options:
            - "single" or FreqGroupPattern.SINGLE: All dimensions (*spatial, level) in
                a single frequency group
            - "partition" or FreqGroupPattern.PARTITION: Spatial dimensions and level
                in separate groups
            - "closure" or FreqGroupPattern.CLOSURE: Three groups - Spatial, level,
                and (*spatial, level)
        device (torch.device, optional): Device for the created tensor. Defaults to None.

    Returns:
        Tensor: Boolean tensor encoding the frequency group pattern, of shape
            [n_freq_groups, position_dim + 1]

    Raises:
        ValueError: If an unrecognized pattern name is provided.
    """
    if isinstance(pattern_name, FreqGroupPattern):
        pattern_name = pattern_name.value

    if pattern_name == "single":
        out = torch.ones(1, position_dim + 1, device=device)
    elif pattern_name == "partition":
        out = torch.zeros(2, position_dim + 1, device=device)
        out[0, :-1] = True  # Spatial dimensions in one group
        out[1, -1] = True  # Level dimension in second group
    elif pattern_name == "closure":
        out = torch.zeros(3, position_dim + 1, device=device)
        out[0, :-1] = True  # Spatial dimensions in one group
        out[1, -1] = True  # Level dimension in second group
        out[2, :] = True  # Third group has all dimensions
    else:
        raise ValueError(f"Unrecognized pattern_name {pattern_name}")

    return out