RoPEEncodingND¶
Overview¶
RoPEEncodingND is the main user-facing API. It is intended to be used in Transformer-style attention as an additional step between the initial Q/K/V in-projection and the query-key product.
Basic usage example for self-attention:
import torch.nn.functional as F
from torch import Tensor, nn

from nd_rotary_encodings import RoPEEncodingND


class RoPEAttention_nd(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        rope_theta: float = 10.0,
        use_checkpointing: bool = False,
        forward_only: bool = False,
        inplace: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.pos_encoding = RoPEEncodingND(
            position_dim=2,
            embed_dim=embed_dim,
            n_heads=num_heads,
            rope_base_theta=rope_theta,
            use_checkpointing=use_checkpointing,
            forward_only=forward_only,
            inplace=inplace,
        )
        self.dropout_p = dropout
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x: Tensor, x_positions: Tensor):
        batch_size, seq_len, embed_dim = x.shape
        head_dim = embed_dim // self.num_heads
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Rotate queries and keys based on their positions, before the
        # query-key product inside scaled_dot_product_attention
        q, k = self.pos_encoding(q, x_positions, k)
        q = q.reshape(batch_size, seq_len, self.num_heads, head_dim).permute(0, 2, 1, 3)
        k = k.reshape(batch_size, seq_len, self.num_heads, head_dim).permute(0, 2, 1, 3)
        v = v.reshape(batch_size, seq_len, self.num_heads, head_dim).permute(0, 2, 1, 3)
        x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout_p)
        x = x.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
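To drive this module, a grid of pixel-index positions can be generated with the static position_grid helper. Below is a minimal sketch assuming the class and imports above; the batch size, grid size, and embedding width are illustrative:

import torch

# Assumes RoPEAttention_nd and the imports from the example above.
attn = RoPEAttention_nd(embed_dim=256, num_heads=8)

# Integer (row, col) positions for a 16x16 grid of patch embeddings,
# shape [16, 16, 2], derived from the embeddings shape.
positions = RoPEEncodingND.position_grid((1, 16, 16, 256))

x = torch.randn(1, 16 * 16, 256)  # [batch, seq_len, embed_dim]
out = attn(x, positions.reshape(1, -1, 2))  # positions flattened to [1, 256, 2]
print(out.shape)  # torch.Size([1, 256, 256])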
RoPEEncodingND
¶
Bases: Module
N-dimensional Rotary Position Embedding (RoPE) module.
Implements rotary position embeddings for arbitrary dimensional positional inputs. This module applies RoPE to queries and keys in attention mechanisms, enabling position-aware attention across N spatial dimensions.
Source code in nd_rotary_encodings/position_encoding_layer/rope_encoding_layer.py
class RoPEEncodingND(nn.Module):
"""N-dimensional Rotary Position Embedding (RoPE) module.
Implements rotary position embeddings for arbitrary dimensional positional inputs.
This module applies RoPE to queries and keys in attention mechanisms, enabling
position-aware attention across N spatial dimensions.
Args:
position_dim (int): Number of position dimensions (e.g., 2 for 2D, 3 for 3D).
embed_dim (int): Total embedding dimension, must be divisible by n_heads.
n_heads (int): Number of attention heads.
share_heads (bool): If True, then only one set of frequencies per frequency
group is created, that is shared among all attention heads, similar to
traditional 1D RoPE. Defaults to False.
freq_group_pattern (Optional[Tensor]): Boolean tensor of shape
[n_freq_groups, position_dim] defining frequency group inclusion. The
(head_dim/2) elements of the RoPE encoding vector will be split among the
frequency groups. If position i, j is True, then frequency group i
includes position dimension j. If None, freq_group_pattern will default
to an all-True tensor of shape [1, position_dim]; i.e., one frequency group
with all position dimensions.
        enforce_freq_groups_equal (bool): If True, then the constructor will raise
            a ValueError if the (head_dim/2) available elements of the RoPE vector
            cannot be evenly split between the frequency groups. If False, then
            trailing frequency groups may have fewer RoPE encodings assigned to
            them. Defaults to True.
        rope_base_theta (Union[Tensor, float]): Base value(s) for frequency scaling.
Can be a single float (applied to all dimensions and frequency groups)
or a 2D tensor of shape [n_freq_groups, position_dim], with either
component dimension allowed to be 1 for broadcasting. Entries corresponding
to non-included position dimensions in a frequency group will be ignored.
Larger values of theta result in lower-frequency rotations, and may be more
suitable for dimensions of greater spatial scale. Default: 10.0
        use_checkpointing (bool): If True, an end-to-end custom autograd Function is used
that recomputes the RoPE rotation tensor(s) from the positions and
frequencies during the backward pass rather than storing it. Since this
tensor scales with the sequence length and embedding dimension, recomputing
it can potentially lead to large memory savings when the sequence length is
large. Default: False
forward_only (bool): If True, additional optimizations involving in-place
updates of intermediate tensors are applied. This mode is safe for
forward passes but is incompatible with autograd. This option is also
incompatible with use_checkpointing. Default: False.
inplace (bool): If True, the embeddings are rotated in-place (i.e., the tensor
values are overwritten). Requires forward_only. Default: False.
dtype (torch.dtype): Data type for the internal frequency parameters.
Default: torch.float
"""
if TYPE_CHECKING:
freq_group_pattern: Tensor
freq_pos_indices: Tensor
freq_group_indices: Tensor
freq_head_indices: Tensor
freq_enc_indices: Tensor
encoding_ranges: Tensor
def __init__(
self,
position_dim: int,
embed_dim: int,
n_heads: int,
share_heads: bool = False,
freq_group_pattern: Optional[Tensor] = None,
enforce_freq_groups_equal: bool = True,
rope_base_theta: Union[Tensor, float, list[list[float]]] = 10.0,
use_checkpointing: bool = False,
forward_only: bool = False,
inplace: bool = False,
dtype=torch.float,
):
"""Initialize the module."""
super().__init__()
if inplace and not forward_only:
raise ValueError("`inplace=True` requires `forward_only=True`")
if forward_only and use_checkpointing:
raise ValueError("`use_checkpointing` is incompatible with `forward_only`")
self.embed_dim = embed_dim
if embed_dim % n_heads != 0:
raise ValueError(
"Expected embed_dim to be divisible by n_heads, got "
f"{embed_dim} and {n_heads}"
)
self.head_dim = embed_dim // n_heads
if self.head_dim % 2 != 0:
raise ValueError(
f"Expected head_dim to be divisible by 2, got {self.head_dim}"
)
self.position_dim = position_dim
self.embed_dim = embed_dim
self.n_heads = n_heads
self.share_heads = share_heads
self.use_checkpointing = use_checkpointing
self.forward_only = forward_only
self.inplace = inplace
if freq_group_pattern is None:
# default frequency group pattern: one group with all position dimensions
freq_group_pattern = torch.ones(1, position_dim, dtype=torch.bool)
freq_group_pattern = torch.as_tensor(freq_group_pattern, dtype=torch.bool)
self.enforce_freq_groups_equal = enforce_freq_groups_equal
self.validate_freq_group_pattern(freq_group_pattern)
self.register_buffer("freq_group_pattern", freq_group_pattern)
self.n_freq_groups = freq_group_pattern.size(0)
self._base_theta = torch.as_tensor(rope_base_theta, dtype=dtype)
self.dtype = dtype
self._init_freq_param()
def _init_freq_param(self):
"""Initialize the frequency parameters for the RoPE module.
Creates and stores the frequency parameters as trainable parameters and
precomputes the indices used to construct the full sparse RoPE frequency
tensor.
"""
effective_n_heads = self.n_heads if not self.share_heads else 1
freqs, encoding_ranges = init_nd_freqs(
self.position_dim,
self.head_dim,
effective_n_heads,
self.freq_group_pattern,
self.enforce_freq_groups_equal,
self._base_theta,
dtype=self.dtype,
)
self.validate_grouped_freqs(freqs, encoding_ranges)
self.freqs = nn.ParameterList(freqs)
self.register_buffer("encoding_ranges", encoding_ranges)
# Precompute indices for grouped_rope_freqs_tensor
indices_list = []
        for g, enc_range in enumerate(encoding_ranges):
            range_start, range_end = (int(r) for r in enc_range)
range_size = range_end - range_start
pos_dims = torch.nonzero(self.freq_group_pattern[g], as_tuple=True)[0]
# if range_size == 0 or pos_dims.numel() == 0: # empty frequency group
# continue
# Create indexing tensors for this frequency group
# Order matches output tensor shape: [position_dim, n_freq_groups, n_heads, head_dim//2]
pos_idx = pos_dims.view(-1, 1, 1).expand(-1, effective_n_heads, range_size)
g_idx = torch.full(
(pos_dims.size(0), effective_n_heads, range_size),
g,
dtype=torch.long,
device=pos_dims.device,
)
head_idx = (
torch.arange(effective_n_heads, device=pos_dims.device)
.view(1, -1, 1)
.expand(pos_dims.size(0), -1, range_size)
)
dim_idx = (
torch.arange(range_start, range_end, device=pos_dims.device)
.view(1, 1, -1)
.expand(pos_dims.size(0), effective_n_heads, -1)
)
# Stack with dimension order matching output tensor
indices = torch.stack(
[
pos_idx.flatten(),
g_idx.flatten(),
head_idx.flatten(),
dim_idx.flatten(),
],
dim=0,
)
indices_list.append(indices)
# Concatenate all indices
indices = torch.cat(indices_list, dim=1)
        # store indices for construction of freq tensor in forward pass
pos_indices, group_indices, head_indices, enc_indices = indices.unbind(0)
self.register_buffer("freq_pos_indices", pos_indices)
self.register_buffer("freq_group_indices", group_indices)
self.register_buffer("freq_head_indices", head_indices)
self.register_buffer("freq_enc_indices", enc_indices)
def validate_freq_group_pattern(self, freq_group_pattern: Tensor):
if freq_group_pattern.ndim != 2:
raise ValueError(
"Expected 2D tensor for freq_group_pattern, got shape "
f"{freq_group_pattern.size()}"
)
if freq_group_pattern.size(1) != self.position_dim:
raise ValueError(
"Expected second dimension of freq_group_pattern to have size equal to "
f"position_dim, got freq_group_pattern shape {freq_group_pattern.size()} "
f"and position_dim={self.position_dim}"
)
n_freq_groups = freq_group_pattern.size(0)
half_head_dim = self.head_dim // 2
remainder = half_head_dim % n_freq_groups
if remainder > 0 and self.enforce_freq_groups_equal:
raise ValueError(
f"RoPE encodings ({half_head_dim}) not evenly divisible by frequency "
f"groups ({n_freq_groups})"
)
def validate_grouped_freqs(self, freqs: list[Tensor], encoding_ranges: Tensor):
# Validate number of frequency groups
n_freq_groups = len(freqs)
if self.freq_group_pattern.size(0) != n_freq_groups:
raise ValueError(
"Expected the first dimension of freq_group_pattern (shape: "
f"{self.freq_group_pattern.shape}) to have size equal to the length of the"
f"freqs list ({len(freqs)})"
)
# Validate head_dim is consistent
        half_head_dim_list = [fg.size(2) for fg in freqs]
if len(set(half_head_dim_list)) != 1 and self.enforce_freq_groups_equal:
raise ValueError(
"Expected tensors in freqs to all have the same number of "
f"RoPE encodings; got {half_head_dim_list}"
)
# Validate n_heads is consistent
        n_heads_list = [fg.size(1) for fg in freqs]
n_heads_set = set(n_heads_list)
if not (
len(n_heads_set) == 1
or (len(n_heads_set) == 2 and len(n_heads_set - set((1,))) == 1)
):
raise ValueError(
"Expected tensors in freqs to have number of attention heads "
f"all equal and/or 1, got {n_heads_list}"
)
# Validate encoding ranges
if encoding_ranges.size(0) != n_freq_groups:
raise ValueError(
"Expected first dim of encoding_ranges to be equal to n_freq_groups "
f"({n_freq_groups}), got shape {encoding_ranges}"
)
if not (
torch.all(encoding_ranges[:, 0] <= encoding_ranges[:, 1])
and torch.all(encoding_ranges[:-1, 1] == encoding_ranges[1:, 0])
):
raise ValueError(
"Expected encoding_ranges to be a 2D tensor of contiguous, "
f"non-overlapping slices, got {encoding_ranges}"
)
def forward(
self,
query: Tensor,
query_pos: Tensor,
key: Optional[Tensor] = None,
key_pos: Optional[Tensor] = None,
) -> Union[Tensor, tuple[Tensor, Tensor]]:
"""Apply rotary position embeddings to query and optionally key tensors.
Applies position-dependent rotations to query and key tensors based on
their associated position information.
Args:
query (Tensor): Query tensor of shape [..., embed_dim].
query_pos (Tensor): Position tensor for query of shape
                [..., position_dim]. The leading dimensions must be broadcastable
                with those of query.
It is assumed that the positions are NOT normalized to the standard
[0, 1] range and are instead the true positions.
key (Optional[Tensor]): Key tensor of shape [..., embed_dim]. Default: None
key_pos (Optional[Tensor]): Position tensor for key of shape
[..., position_dim]. If None and key is provided, query_pos will be
used. It is assumed that the positions are NOT normalized to the
standard [0, 1] range and are instead the true positions. Default: None
Returns:
Union[Tensor, tuple[Tensor, Tensor]]:
- If key is None: Rotated query tensor of same shape as input query.
- If key is provided: Tuple of (rotated query, rotated key) tensors.
Note:
- For query/key embeddings with a regular grid structure, a default
position grid may be obtained from the static method `position_grid`.
Raises:
ValueError: If the tensor shapes are incompatible.
Warns:
UserWarning: If position coordinates appear to be normalized
(in [0,1] range).
"""
self.shape_check(query, query_pos)
if query_pos.numel() > 0 and query_pos.min() > 0.0 and query_pos.max() <= 1.0:
warnings.warn(
"Expected un-normalized (i.e., not inside [0,1]) coordinates "
"for position but found potentially normalized coordinates. "
"Did you accidentally pass in normalized coordinates?\n(Your coord "
f"range: [{query_pos.min().item(), query_pos.max().item()}])",
UserWarning,
)
if key_pos is not None: # Check key if present
assert key is not None
self.shape_check(key, key_pos)
# Construct full frequency tensor from component parameters
freq_tensor = self.grouped_rope_freqs_tensor(self.freqs)
# unstack heads
query = query.reshape(query.shape[:-1] + (self.n_heads, self.head_dim))
if key is not None:
key = key.reshape(key.shape[:-1] + (self.n_heads, self.head_dim))
key_rotated = None
# select proper path
if self.forward_only or self.use_checkpointing:
# use end-to-end function
if key is not None and key_pos is None:
# query and key share same positions
query_rotated, key_rotated = self.apply_rope_endtoend(
query, query_pos, freq_tensor, self.forward_only, self.inplace, key
)
else:
                query_rotated = self.apply_rope_endtoend(
                    query, query_pos, freq_tensor, self.forward_only, self.inplace
                )
if key is not None:
assert key_pos is not None
key_rotated = self.apply_rope_endtoend(
key, key_pos, freq_tensor, self.forward_only, self.inplace
)
else:
query_rot_vec = self.calculate_rope(query_pos, freq_tensor)
query_rotated = self.rotate_embeddings(query, query_rot_vec)
if key is not None:
if key_pos is None:
key_rotated = self.rotate_embeddings(key, query_rot_vec)
else:
key_rot_vec = self.calculate_rope(key_pos, freq_tensor)
key_rotated = self.rotate_embeddings(key, key_rot_vec)
assert isinstance(query_rotated, Tensor)
# stack heads back
query_rotated = query_rotated.view(query_rotated.shape[:-2] + (self.embed_dim,))
if key is None:
return query_rotated
assert isinstance(key_rotated, Tensor)
key_rotated = key_rotated.view(key_rotated.shape[:-2] + (self.embed_dim,))
return query_rotated, key_rotated
@staticmethod
def position_grid(
embeddings_shape: Union[Sequence[int], Tensor],
start_dim: int = 1,
end_dim: int = -1,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> Tensor:
"""Generates a regularly-spaced grid of positions based on the input shape.
This function may be used to generate a tensor of positions corresponding to
the tensor indices of each element in the embeddings tensor. This is
potentially useful for regularly-spaced queries and keys, such as embeddings
corresponding to text tokens or image pixels. The exact form of the position
grid tensor is torch.stack(torch.meshgrid(
*[torch.arange(size) for size in embeddings_shape[start_dim:end_dim]],
indexing="ij"
), dim=-1)
Args:
embeddings_shape (Sequence[int] | Tensor): The full shape of the embeddings tensor.
start_dim (int, optional): Start index of the position dimensions in
embeddings_shape, inclusive. Defaults to 1 (i.e., one batch dim).
end_dim (int, optional): End index of the position dimensions in
embeddings_shape, exclusive. Defaults to -1 (i.e., one feature dim).
device (Optional[Union[str, torch.device]]): The device on which to create the
tensor. Defaults to None (i.e., default device).
            dtype (Optional[torch.dtype]): The dtype for the created tensor. Defaults
to None (i.e., default dtype).
Returns:
Tensor: Created position grid tensor, of shape
[*embeddings_shape[start_dim:end_dim],
len(embeddings_shape[start_dim:end_dim])]
"""
grid = torch.stack(
torch.meshgrid(
*[
torch.arange(int(size), device=device, dtype=dtype)
for size in embeddings_shape[start_dim:end_dim]
],
indexing="ij",
),
dim=-1,
)
return grid
def grouped_rope_freqs_tensor(
self,
grouped_rope_freqs: Union[list[Tensor], nn.ParameterList],
) -> Tensor:
"""Use frequency group information to build the full RoPE frequency tensor that
is multiplied by the positions to produce RoPE encodings.
This function takes the per-group RoPE frequencies to construct the RoPE
frequency tensor. The RoPE frequency tensor has shape
[position_dim, n_freq_groups, n_heads, head_dim/2], and is zero at positions
where a position dimension is not included in a frequency group. The
frequencies are stored separately per frequency group in tensors of shape
[position_dim_g, n_heads, group_encoding_dim] because each frequency group may
have a different number of active position dimensions and/or assigned encoding
dimensions.
Args:
grouped_rope_freqs (list[Tensor]): List of per-group frequency tensors, as
generated by init_nd_freqs, each of shape
[
position_dim_g,
n_heads,
{head_dim//(2 * n_freq_groups), head_dim//(2 * n_freq_groups) + 1}
],
where position_dim_g is the number of position dimensions included in
frequency group g.
Returns:
Tensor: RoPE frequency tensor of shape
[position_dim, n_freq_groups, n_heads, head_dim/2] or
[position_dim, n_freq_groups, 1, head_dim/2], with nonzero
elements corresponding to position dimensions included in each
frequency group. It may be passed to `calculate_rope` with the
positions tensor to compute RoPE encodings.
"""
if isinstance(grouped_rope_freqs, Tensor):
grouped_rope_freqs = [grouped_rope_freqs]
# Create output tensor
rope_freqs = grouped_rope_freqs[0].new_zeros(
self.position_dim,
self.n_freq_groups,
self.n_heads if not self.share_heads else 1,
self.head_dim // 2,
)
values = torch.cat([fg.flatten() for fg in grouped_rope_freqs])
rope_freqs.index_put_(
(
self.freq_pos_indices,
self.freq_group_indices,
self.freq_head_indices,
self.freq_enc_indices,
),
values,
)
return rope_freqs
@staticmethod
def calculate_rope(positions: Tensor, rope_freqs: Tensor) -> Tensor:
"""Creates rotation vectors from position coordinates and RoPE frequencies.
Transforms positional information into rotation vectors for RoPE.
Args:
positions (Tensor): Position tensor of shape [..., position_dim].
rope_freqs (Tensor): Frequency tensor for rotary encodings of shape
[position_dim, n_freq_groups, n_heads, head_dim/2].
Returns:
Tensor: Real-valued positional encodings of shape
[..., n_heads, head_dim/2].
"""
return calculate_rope(positions.to(rope_freqs), rope_freqs)
@staticmethod
def rotate_embeddings(
query_or_key: Tensor,
rope_encoding: Tensor,
forward_only: bool = False,
inplace: bool = False,
) -> Tensor:
"""Applies rotary embeddings to query or key tensor using complex
multiplication.
Rotates the query or key tensor using the rotation vectors via complex
multiplication.
Args:
query_or_key (Tensor): Query or key tensor of shape
[..., n_heads, head_dim].
            rope_encoding (Tensor): Real-valued RoPE encoding tensor of shape
                [..., n_heads, head_dim/2].
            forward_only (bool): If True, uses the non-autograd rotation
                implementation. Default: False.
            inplace (bool): If True and forward_only is True, rotates the
                embeddings in-place. Default: False.
Returns:
Tensor: Rotated query or key tensor of same shape as input query_or_key.
"""
# Unsqueeze rope_encoding if needed
dim_diff = query_or_key.ndim - rope_encoding.ndim
if dim_diff > 0:
rope_encoding = rope_encoding.view((1,) * dim_diff + rope_encoding.shape)
if forward_only:
return rotate_embeddings_forward_only(
query_or_key, rope_encoding, inplace=inplace
)
return rotate_embeddings(query_or_key, rope_encoding)
def apply_rope_endtoend(
self,
embeddings: Tensor,
positions: Tensor,
rope_freqs: Tensor,
forward_only: bool = False,
inplace: bool = False,
key_embeddings: Optional[Tensor] = None
) -> Union[Tensor, tuple[Tensor, Tensor]]:
"""Memory-optimized calculation and application of RoPE from components.
Uses an end-to-end function to calculate RoPE rotation vectors from
position coordinates and RoPE frequencies, then apply them to query or key
embeddings.
Under the hood, this uses a custom autograd Function and explicit backward
calculation to save memory, at the cost of recalculating the RoPE rotation
vectors during the backward pass.
        Args:
            embeddings (Tensor): Embeddings tensor of shape
                [..., n_heads, head_dim].
            positions (Tensor): Position tensor of shape [..., position_dim].
            rope_freqs (Tensor): Frequency tensor for rotary encodings of shape
                [position_dim, n_freq_groups, n_heads, head_dim/2].
            forward_only (bool): If True, uses the forward-only implementation
                with optional in-place updates. Default: False.
            inplace (bool): If True and forward_only is True, rotates the
                embeddings in-place. Default: False.
            key_embeddings (Optional[Tensor]): Key embeddings that share the same
                positions as embeddings (the self-attention case). If provided,
                both tensors are rotated together. Default: None.
        Returns:
            Union[Tensor, tuple[Tensor, Tensor]]: Rotated embeddings of the same
                shape as the input, or a tuple of rotated (embeddings,
                key_embeddings) if key_embeddings is provided.
"""
if forward_only:
return apply_rope_forward_only(
embeddings, positions.to(rope_freqs), rope_freqs, inplace=inplace,
self_attn_key_embeddings=key_embeddings
)
return apply_rope_checkpointed(
embeddings, positions.to(rope_freqs), rope_freqs, key_embeddings=key_embeddings
)
def shape_check(self, query_or_key: Tensor, query_or_key_pos: Tensor):
"""Validates the shapes of query/key and their position tensors.
Args:
query_or_key (Tensor): Query or key tensor of shape [..., embed_dim].
query_or_key_pos (Tensor): Position tensor of shape [..., position_dim].
Must be broadcastable to the shape of query_or_key.
Raises:
ValueError: If tensor shapes are incompatible.
"""
if not can_broadcast_shapes(
query_or_key.shape[:-1], query_or_key_pos.shape[:-1]
):
raise ValueError(
"Expected leading dims of query_or_key_pos to be broadcastable to "
"leading dims of query_or_key, but got shapes "
f"{query_or_key_pos.shape} and {query_or_key.shape}, respectively."
)
if query_or_key.shape[-1] != self.embed_dim:
raise ValueError(
"Expected query_or_key to have last dim equal to embed_dim "
f"(={self.embed_dim}), got {query_or_key.shape[-1]}"
)
if query_or_key_pos.shape[-1] != self.position_dim:
raise ValueError(
"Expected query_or_key_pos to have last dim equal to pos_dim "
f"(={self.position_dim}), got {query_or_key_pos.shape[-1]}"
)
def reset_parameters(self):
"""Resets frequency parameters"""
freqs, _ = init_nd_freqs(
self.position_dim,
self.head_dim,
self.n_heads if not self.share_heads else 1,
self.freq_group_pattern,
self.enforce_freq_groups_equal,
self._base_theta,
dtype=self.dtype,
device=self.freqs[0].device,
)
with torch.no_grad():
for param, init in zip(self.freqs, freqs):
param.copy_(init)
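The freq_group_pattern argument is easiest to see with a small sketch. Here three position dimensions (e.g., x, y, and a level index) are split into a spatial frequency group and a level frequency group; the pattern, shapes, and sizes below are invented for illustration:

import torch

from nd_rotary_encodings import RoPEEncodingND

# Group 0 encodes the two spatial dimensions; group 1 encodes the level
# dimension. With head_dim = 256 // 8 = 32, the 16 encoding pairs are
# split evenly: 8 per frequency group.
pattern = torch.tensor(
    [[True, True, False],
     [False, False, True]]
)

rope = RoPEEncodingND(
    position_dim=3,
    embed_dim=256,
    n_heads=8,
    freq_group_pattern=pattern,
)

q = torch.randn(2, 100, 256)                       # [batch, seq, embed_dim]
q_pos = torch.randint(0, 32, (2, 100, 3)).float()  # un-normalized positions
q_rotated = rope(q, q_pos)                         # same shape as q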
forward(query, query_pos, key=None, key_pos=None)
¶
Apply rotary position embeddings to query and optionally key tensors.
Applies position-dependent rotations to query and key tensors based on their associated position information.
Note
- For query/key embeddings with a regular grid structure, a default position grid may be obtained from the static method position_grid.
Source code in nd_rotary_encodings/position_encoding_layer/rope_encoding_layer.py
def forward(
self,
query: Tensor,
query_pos: Tensor,
key: Optional[Tensor] = None,
key_pos: Optional[Tensor] = None,
) -> Union[Tensor, tuple[Tensor, Tensor]]:
"""Apply rotary position embeddings to query and optionally key tensors.
Applies position-dependent rotations to query and key tensors based on
their associated position information.
Args:
query (Tensor): Query tensor of shape [..., embed_dim].
query_pos (Tensor): Position tensor for query of shape
                [..., position_dim]. The leading dimensions must be broadcastable
                with those of query.
It is assumed that the positions are NOT normalized to the standard
[0, 1] range and are instead the true positions.
key (Optional[Tensor]): Key tensor of shape [..., embed_dim]. Default: None
key_pos (Optional[Tensor]): Position tensor for key of shape
[..., position_dim]. If None and key is provided, query_pos will be
used. It is assumed that the positions are NOT normalized to the
standard [0, 1] range and are instead the true positions. Default: None
Returns:
Union[Tensor, tuple[Tensor, Tensor]]:
- If key is None: Rotated query tensor of same shape as input query.
- If key is provided: Tuple of (rotated query, rotated key) tensors.
Note:
- For query/key embeddings with a regular grid structure, a default
position grid may be obtained from the static method `position_grid`.
Raises:
ValueError: If the tensor shapes are incompatible.
Warns:
UserWarning: If position coordinates appear to be normalized
(in [0,1] range).
"""
self.shape_check(query, query_pos)
if query_pos.numel() > 0 and query_pos.min() > 0.0 and query_pos.max() <= 1.0:
warnings.warn(
"Expected un-normalized (i.e., not inside [0,1]) coordinates "
"for position but found potentially normalized coordinates. "
"Did you accidentally pass in normalized coordinates?\n(Your coord "
f"range: [{query_pos.min().item(), query_pos.max().item()}])",
UserWarning,
)
if key_pos is not None: # Check key if present
assert key is not None
self.shape_check(key, key_pos)
# Construct full frequency tensor from component parameters
freq_tensor = self.grouped_rope_freqs_tensor(self.freqs)
# unstack heads
query = query.reshape(query.shape[:-1] + (self.n_heads, self.head_dim))
if key is not None:
key = key.reshape(key.shape[:-1] + (self.n_heads, self.head_dim))
key_rotated = None
# select proper path
if self.forward_only or self.use_checkpointing:
# use end-to-end function
if key is not None and key_pos is None:
# query and key share same positions
query_rotated, key_rotated = self.apply_rope_endtoend(
query, query_pos, freq_tensor, self.forward_only, self.inplace, key
)
else:
                query_rotated = self.apply_rope_endtoend(
                    query, query_pos, freq_tensor, self.forward_only, self.inplace
                )
if key is not None:
assert key_pos is not None
key_rotated = self.apply_rope_endtoend(
key, key_pos, freq_tensor, self.forward_only, self.inplace
)
else:
query_rot_vec = self.calculate_rope(query_pos, freq_tensor)
query_rotated = self.rotate_embeddings(query, query_rot_vec)
if key is not None:
if key_pos is None:
key_rotated = self.rotate_embeddings(key, query_rot_vec)
else:
key_rot_vec = self.calculate_rope(key_pos, freq_tensor)
key_rotated = self.rotate_embeddings(key, key_rot_vec)
assert isinstance(query_rotated, Tensor)
# stack heads back
query_rotated = query_rotated.view(query_rotated.shape[:-2] + (self.embed_dim,))
if key is None:
return query_rotated
assert isinstance(key_rotated, Tensor)
key_rotated = key_rotated.view(key_rotated.shape[:-2] + (self.embed_dim,))
return query_rotated, key_rotated
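For cross-attention, where queries and keys live at different positions, both position tensors are passed explicitly. A short sketch with assumed shapes:

import torch

from nd_rotary_encodings import RoPEEncodingND

rope = RoPEEncodingND(position_dim=2, embed_dim=64, n_heads=4)

q = torch.randn(2, 50, 64)          # queries: [batch, n_queries, embed_dim]
k = torch.randn(2, 200, 64)         # keys: [batch, n_keys, embed_dim]
q_pos = torch.rand(2, 50, 2) * 32   # un-normalized 2D query positions
k_pos = torch.rand(2, 200, 2) * 32  # un-normalized 2D key positions

q_rot, k_rot = rope(q, q_pos, k, k_pos)

If key_pos is omitted, the keys reuse q_pos, which matches the self-attention case.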
position_grid(embeddings_shape, start_dim=1, end_dim=-1, device=None, dtype=None)
staticmethod
¶
Generates a regularly-spaced grid of positions based on the input shape.
This function may be used to generate a tensor of positions corresponding to the tensor indices of each element in the embeddings tensor. This is potentially useful for regularly-spaced queries and keys, such as embeddings corresponding to text tokens or image pixels. The exact form of the position grid tensor is:
torch.stack(torch.meshgrid(*[torch.arange(size) for size in embeddings_shape[start_dim:end_dim]], indexing="ij"), dim=-1)
Source code in nd_rotary_encodings/position_encoding_layer/rope_encoding_layer.py
@staticmethod
def position_grid(
embeddings_shape: Union[Sequence[int], Tensor],
start_dim: int = 1,
end_dim: int = -1,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> Tensor:
"""Generates a regularly-spaced grid of positions based on the input shape.
This function may be used to generate a tensor of positions corresponding to
the tensor indices of each element in the embeddings tensor. This is
potentially useful for regularly-spaced queries and keys, such as embeddings
corresponding to text tokens or image pixels. The exact form of the position
grid tensor is torch.stack(torch.meshgrid(
*[torch.arange(size) for size in embeddings_shape[start_dim:end_dim]],
indexing="ij"
), dim=-1)
Args:
embeddings_shape (Sequence[int] | Tensor): The full shape of the embeddings tensor.
start_dim (int, optional): Start index of the position dimensions in
embeddings_shape, inclusive. Defaults to 1 (i.e., one batch dim).
end_dim (int, optional): End index of the position dimensions in
embeddings_shape, exclusive. Defaults to -1 (i.e., one feature dim).
device (Optional[Union[str, torch.device]]): The device on which to create the
tensor. Defaults to None (i.e., default device).
            dtype (Optional[torch.dtype]): The dtype for the created tensor. Defaults
to None (i.e., default dtype).
Returns:
Tensor: Created position grid tensor, of shape
[*embeddings_shape[start_dim:end_dim],
len(embeddings_shape[start_dim:end_dim])]
"""
grid = torch.stack(
torch.meshgrid(
*[
torch.arange(int(size), device=device, dtype=dtype)
for size in embeddings_shape[start_dim:end_dim]
],
indexing="ij",
),
dim=-1,
)
return grid
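A brief sketch of the helper on an image-like shape (the sizes are illustrative):

from nd_rotary_encodings import RoPEEncodingND

# Positions for an embeddings tensor of shape [batch=4, H=8, W=8, C=128]:
# the default start_dim=1 and end_dim=-1 select the two spatial dimensions.
grid = RoPEEncodingND.position_grid((4, 8, 8, 128))
print(grid.shape)  # torch.Size([8, 8, 2])
print(grid[3, 5])  # tensor([3, 5]) -- the (row, col) index of that element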
Utilities¶
prep_multilevel_positions(spatial_positions, batch_indices, level_indices, level_spatial_shapes)
¶
Standardizes positional coordinates across multiple resolution levels.
Converts indices or positions from multiple resolution levels to a standardized coordinate system by rescaling each level to match the finest level's resolution. This enables consistent position encoding across hierarchical feature maps.
Source code in nd_rotary_encodings/position_encoding_layer/utils.py
def prep_multilevel_positions(
spatial_positions: Tensor,
batch_indices: Tensor,
level_indices: Tensor,
level_spatial_shapes: Tensor,
) -> Tensor:
"""Standardizes positional coordinates across multiple resolution levels.
Converts indices or positions from multiple resolution levels to a standardized
coordinate system by rescaling each level to match the finest level's resolution.
This enables consistent position encoding across hierarchical feature maps.
Args:
spatial_positions (Tensor): Indices or positions of shape [..., position_dim],
where each row contains the N-D position of each point. If floating point,
they're treated as coordinates; if integer, they're treated as indices.
batch_indices (Tensor): Integer tensor of shape [...], containing the
batch index for each position in spatial_positions.
level_indices (Tensor): Integer tensor of shape [...], containing the
level index for each position in spatial_positions.
        level_spatial_shapes (Tensor): Tensor of shape [num_levels, position_dim]
            or [batch_size, num_levels, position_dim] specifying the spatial
            dimensions (e.g., height and width) of each level.
Returns:
Tensor: Rescaled positions of shape [..., position_dim + 1] with floating
point dtype, where the last dimension has the level index concatenated onto
the end of the spatial coordinates, and the spatial coordinates are
standardized to the finest resolution level.
Raises:
ValueError: If tensors don't have the expected shape, dimensions, or dtypes.
"""
validate_atleast_nd(spatial_positions, 2, "spatial_positions")
batch_dims = spatial_positions.ndim - 1
validate_nd(batch_indices, batch_dims, "batch_indices")
validate_nd(level_indices, batch_dims, "level_indices")
if not torch.is_floating_point(spatial_positions):
# convert from indices to coordinates of pixel centers
spatial_positions = spatial_positions + 0.5
# batch, level, pos_dim or level, pos_dim
assert level_spatial_shapes.ndim in (2, 3)
# Initialize output tensor
multilevel_positions = spatial_positions.new_zeros(
spatial_positions.shape[:-1] + (spatial_positions.size(-1) + 1,)
)
# Early exit
if multilevel_positions.numel() == 0:
return multilevel_positions
if level_spatial_shapes.ndim == 2:
level_spatial_shapes = level_spatial_shapes.unsqueeze(0).expand(
int(torch.max(batch_indices).item()) + 1, -1, -1
)
batch_max_spatial_shape = level_spatial_shapes.max(-2)[0]
max_spatial_shapes = batch_max_spatial_shape[batch_indices]
indexed_spatial_shapes = level_spatial_shapes[batch_indices, level_indices]
# Fill in rescaled positions
multilevel_positions[..., :-1] = spatial_positions / (
indexed_spatial_shapes / max_spatial_shapes
)
# Fill in level indices
multilevel_positions[..., -1] = level_indices.to(multilevel_positions)
return multilevel_positions
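A small worked sketch, importing from the module path shown above (the shapes and values are invented for illustration). A point on a coarse 16x16 level is rescaled to the finest 32x32 level's coordinate system, and the level index is appended:

import torch

from nd_rotary_encodings.position_encoding_layer.utils import (
    prep_multilevel_positions,
)

# Two feature levels: a fine 32x32 level and a coarse 16x16 level.
level_spatial_shapes = torch.tensor([[32, 32], [16, 16]])

# Three points as integer (i, j) indices; the first two on level 0,
# the third on level 1. All belong to batch element 0.
spatial_positions = torch.tensor([[4, 7], [10, 3], [5, 5]])
batch_indices = torch.zeros(3, dtype=torch.long)
level_indices = torch.tensor([0, 0, 1])

out = prep_multilevel_positions(
    spatial_positions, batch_indices, level_indices, level_spatial_shapes
)
print(out.shape)  # torch.Size([3, 3]) -- (i, j, level index)
print(out[2])     # tensor([11., 11., 1.]): pixel center (5.5, 5.5) scaled by 2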
get_multilevel_freq_group_pattern(position_dim, pattern_name, device=None)
¶
Get a predefined frequency group pattern for RoPE encodings of multilevel features.
Creates a frequency group pattern tensor for use with RoPEEncodingND based on predefined patterns that determine how spatial and level dimensions are encoded.
Source code in nd_rotary_encodings/position_encoding_layer/utils.py
def get_multilevel_freq_group_pattern(
position_dim: int, pattern_name: Union[str, FreqGroupPattern], device=None
) -> Tensor:
"""Get a predefined frequency group pattern for RoPE encodings of multilevel features.
Creates a frequency group pattern tensor for use with RoPEEncodingND based on
predefined patterns that determine how spatial and level dimensions are encoded.
Args:
position_dim (int): Spatial dimension of the features to be encoded (2 for 2D
images, etc.). The output tensor will have this many spatial dimensions
            plus 1 dimension for the feature level.
pattern_name (Union[str, FreqGroupPattern]): Pattern to use, either as a string
or enum value. Options:
- "single" or FreqGroupPattern.SINGLE: All dimensions (*spatial, level) in
a single frequency group
- "partition" or FreqGroupPattern.PARTITION: Spatial dimensions and level
in separate groups
- "closure" or FreqGroupPattern.CLOSURE: Three groups - Spatial, level,
and (*spatial, level)
device (torch.device, optional): Device for the created tensor. Defaults to None.
Returns:
Tensor: Boolean tensor encoding the frequency group pattern, of shape
[n_freq_groups, position_dim + 1]
Raises:
ValueError: If an unrecognized pattern name is provided.
"""
if isinstance(pattern_name, FreqGroupPattern):
pattern_name = pattern_name.value
if pattern_name == "single":
out = torch.ones(1, position_dim + 1, device=device)
elif pattern_name == "partition":
out = torch.zeros(2, position_dim + 1, device=device)
out[0, :-1] = True # Spatial dimensions in one group
out[1, -1] = True # Level dimension in second group
elif pattern_name == "closure":
out = torch.zeros(3, position_dim + 1, device=device)
out[0, :-1] = True # Spatial dimensions in one group
out[1, -1] = True # Level dimension in second group
out[2, :] = True # Third group has all dimensions
else:
raise ValueError(f"Unrecognized pattern_name {pattern_name}")
return out
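A quick sketch of the returned pattern, importing from the module path shown above and assuming the boolean dtype documented in the docstring:

from nd_rotary_encodings.position_encoding_layer.utils import (
    get_multilevel_freq_group_pattern,
)

# "partition" for 2D features: spatial (x, y) in one group, level in another.
pattern = get_multilevel_freq_group_pattern(2, "partition")
print(pattern)
# tensor([[ True,  True, False],
#         [False, False,  True]])

The result can be passed directly as the freq_group_pattern argument of RoPEEncodingND (with position_dim=3 in this case: two spatial dimensions plus the level index).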