Functional API

The functional API contains the low-level operations used by RoPEEncodingND. They include custom autograd Functions, TorchScript-optimized forward and backward kernels, and various wrappers.

Applying RoPE to embeddings, given positions and RoPE frequencies, is broken into two steps:

1. Compute the RoPE rotation tensor from the positions and frequencies.
2. Rotate the embeddings by multiplying them by the RoPE rotation tensor.

The RoPE rotation tensor is fast to compute but can be very large when the sequence and/or batch dimensions are large, so recomputing it during the backward pass rather than storing it is a key optimization.
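
As a concrete sketch of this two-step flow (tensor sizes are illustrative, and the import path is assumed from the "Source code in ..." locations noted below; it may differ in your installation):

import torch
# Assumed import path, based on the source file locations shown on this page
from nd_rotary_encodings.functional.forward_backward_fns import (
    calculate_rope,
    rotate_embeddings,
)

batch, seq_len, n_heads, head_dim, position_dim = 2, 16, 4, 64, 2

embeddings = torch.randn(batch, seq_len, n_heads, head_dim)
positions = torch.randn(batch, seq_len, position_dim)
# One frequency group, broadcast over heads
rope_freqs = torch.randn(position_dim, 1, 1, head_dim // 2)

# Step 1: rotation angles of shape [batch, seq_len, 1, head_dim/2]
rope_encoding = calculate_rope(positions, rope_freqs)
# Step 2: rotated embeddings of shape [batch, seq_len, n_heads, head_dim]
rotated = rotate_embeddings(embeddings, rope_encoding)

The checkpointed and forward-only wrappers further down this page fuse or streamline parts of this pipeline.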


Forward and backward kernels

calculate_rope(positions, rope_freqs)

Computes the positional encoding for embeddings tensors using the provided positions and frequency values.

This function calculates the rotary position encoding by matrix-multiplying embedding positions with rotary frequency encodings, then summing over frequency groups. The returned positional encoding is in real space and must be converted to complex coordinates (e.g., with torch.polar) before being multiplied with the complex representation of the embedding tensor; this conversion is handled by rotate_embeddings. This function may be used in combination with the others in its module for memory-efficient RoPE application over many positions.

This implementation allows position dimensions to be grouped into specific frequency groups, so that dimensions with potentially different spatial characteristics (e.g., x and y vs. time for videos) can be encoded separately. This generalization is experimental and under active research. If dimension i is not in frequency group j, then rope_freqs[i, j] should be 0. For traditional RoPE, keep n_freq_groups at 1.

Parameters:
  • positions (Tensor) –

    Position information for each embedding element of shape [..., position_dim], where ... are arbitrary batch dimensions and position_dim is the dimensionality of the position representation.

  • rope_freqs (Tensor) –

    Frequency values for rotary encodings of shape [position_dim, n_freq_groups, n_heads, head_dim/2], where n_freq_groups and n_heads can be 1 for broadcasting.

Returns:
  • Tensor –

    Computed positional encoding of shape [..., n_heads, head_dim/2]

Source code in nd_rotary_encodings/functional/forward_backward_fns.py
@torch.jit.script
def calculate_rope(positions: Tensor, rope_freqs: Tensor) -> Tensor:
    """Computes the positional encoding for embeddings tensors using the provided
    positions and frequency values.

    This function calculates the rotary position encoding by matrix-multiplying
    embedding positions with rotary frequency encodings, then summing over frequency
    groups.
    The returned positional encoding will be in real space, and must be converted
    to complex coordinates with e.g. torch.polar before multiplying with the
    complex representation of the embedding tensor (this conversion is handled by
    rotate_embeddings).
    This function may be used in combination with the others in its module for a
    memory-efficient RoPE application over many positions.
    This implementation allows for grouping of position dimensions into specific
    frequency groups. The intention is to allow dimensions with potentially different
    spatial characteristics (e.g., x and y vs time for videos) to be grouped
    separately. This generalization is experimental and under active research.
    If dimension i is not in frequency group j, then rope_freqs[i, j] should be 0.
    For traditional RoPE, keep n_freq_groups as 1.

    Args:
        positions (Tensor): Position information for each embedding element of shape
            [..., position_dim], where ... are arbitrary batch dimensions and
            position_dim is the dimensionality of the position representation.
        rope_freqs (Tensor): Frequency values for rotary encodings of shape
            [position_dim, n_freq_groups, n_heads, head_dim/2], where n_freq_groups
            and n_heads can be 1 for broadcasting.

    Returns:
        Tensor: Computed positional encoding of shape
            [..., n_heads, head_dim/2]
    """
    validate_at_least_nd(positions, "positions", 2)
    validate_4d(rope_freqs, "rope_freqs")
    validate_position_dim(positions, rope_freqs)

    batch_dims = positions.shape[:-1]
    position_dim = positions.size(-1)
    _, n_freq_groups, n_heads, half_head_dim = rope_freqs.shape

    # flatten batch dimensions
    positions_flat = positions.reshape(-1, position_dim)

    # [position_dim, n_freq_groups*n_heads*head_dim/2]
    rope_freqs_flat = rope_freqs.reshape(position_dim, -1)

    # Compute position encoding
    rope_encoding = torch.mm(positions_flat, rope_freqs_flat)
    # shape: [prod(batch_dims), n_freq_groups*n_heads*head_dim/2]

    # reshape back to input batch dims
    output_shape = batch_dims + (n_freq_groups, n_heads, half_head_dim)
    rope_encoding = rope_encoding.view(output_shape)

    # Sum over frequency groups
    # [*batch_dims, n_heads, head_dim/2]
    rope_encoding = rope_encoding.sum(dim=-3)
    return rope_encoding
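
As a sketch of the (experimental) frequency-group mechanism described above: position dimensions that do not belong to a group get zero frequencies for that group. Values and shapes are illustrative, and the import path is assumed:

import torch
from nd_rotary_encodings.functional.forward_backward_fns import calculate_rope  # assumed path

# 3 position dims (x, y, t) and 2 frequency groups:
# the spatial dims drive group 0, the time dim drives group 1.
position_dim, n_freq_groups, n_heads, half_head_dim = 3, 2, 1, 32
rope_freqs = torch.zeros(position_dim, n_freq_groups, n_heads, half_head_dim)
rope_freqs[0:2, 0] = torch.rand(2, n_heads, half_head_dim)  # x, y -> group 0
rope_freqs[2, 1] = torch.rand(n_heads, half_head_dim)       # t -> group 1

positions = torch.randn(4, 100, position_dim)            # [batch, seq, position_dim]
rope_encoding = calculate_rope(positions, rope_freqs)     # [4, 100, n_heads, 32]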

calculate_rope_backward(grad_rope_encoding, positions, rope_freqs, needs_grad_positions, needs_grad_rope_freqs)

Calculates gradients for the calculate_rope function.

This function implements the backward pass for the calculation of the rotary positional encoding tensor that gets multiplied with the query/key tensor. It propagates the gradients from rope_encoding to positions and rope_freqs.

Parameters:
  • grad_rope_encoding (Tensor) –

    Real-valued gradient of loss with respect to the positional encoding, of shape [..., n_heads, head_dim/2]

  • positions (Tensor) –

    Position tensor from the forward pass, of shape [..., position_dim]

  • rope_freqs (Tensor) –

    Frequency values tensor from the forward pass, of shape [position_dim, n_freq_groups, n_heads, head_dim/2], with n_freq_groups and/or n_heads also allowed to be 1.

  • needs_grad_positions (bool) –

    Whether grad for positions is required

  • needs_grad_rope_freqs (bool) –

    Whether grad for rope_freqs is required

Returns:
  • tuple[Optional[Tensor], Optional[Tensor]] –

    - grad_positions: Gradient tensor for positions of shape [..., position_dim], or None if not needed
    - grad_rope_freqs: Gradient tensor for rope frequencies of same shape as input tensor rope_freqs, or None if not needed

Source code in nd_rotary_encodings/functional/forward_backward_fns.py
@torch.jit.script
def calculate_rope_backward(
    grad_rope_encoding: Tensor,
    positions: Tensor,
    rope_freqs: Tensor,
    needs_grad_positions: bool,
    needs_grad_rope_freqs: bool,
) -> tuple[Optional[Tensor], Optional[Tensor]]:
    """Calculates gradients for the calculate_rope function.

    This function implements the backward pass for the calculation of the rotary
    positional encoding tensor that gets multiplied with the query/key tensor. It
    propagates the gradients from rope_encoding to positions and rope_freqs.

    Args:
        grad_rope_encoding (Tensor): Real-valued gradient of loss with respect to
            the positional encoding, of shape
            [..., n_heads, head_dim/2]
        positions (Tensor): Position tensor from the forward pass, of shape
            [..., position_dim]
        rope_freqs (Tensor): Frequency values tensor from the forward pass, of shape
            [position_dim, n_freq_groups, n_heads, head_dim/2], with n_freq_groups
            and/or n_heads also allowed to be 1.
        needs_grad_positions (bool): Whether grad for positions is required
        needs_grad_rope_freqs (bool): Whether grad for rope_freqs is required

    Returns:
        tuple[Optional[Tensor], Optional[Tensor]]:
            - grad_positions: Gradient tensor for positions of shape
              [..., position_dim], or None if not needed
            - grad_rope_freqs: Gradient tensor for rope frequencies of same
              shape as input tensor rope_freqs, or None if not needed
    """
    validate_at_least_nd(positions, "positions", 2)
    validate_at_least_nd(grad_rope_encoding, "grad_rope_encoding", 3)
    validate_4d(rope_freqs, "rope_freqs")
    validate_position_dim(positions, rope_freqs)

    positions_in_shape = positions.shape
    rope_freqs_in_shape = rope_freqs.shape

    positions = positions.expand(grad_rope_encoding.shape[:-2] + (positions.size(-1),))
    n_heads = grad_rope_encoding.size(-2)
    rope_freqs = rope_freqs.expand(-1, -1, n_heads, -1)

    batch_dims = positions.shape[:-1]
    position_dim = positions.size(-1)
    _, n_freq_groups, _, half_head_dim = rope_freqs.shape

    # Check for no grads needed
    if not needs_grad_positions and not needs_grad_rope_freqs:
        # Early return
        return None, None

    # Backward of sum: distribute gradient across n_freq_groups
    grad_mm_result = grad_rope_encoding.unsqueeze(-3).expand(
        (-1,) * len(batch_dims) + (n_freq_groups, -1, -1)
    )

    # Reshape to match the mm result
    grad_mm_result = grad_mm_result.reshape(-1, n_freq_groups * n_heads * half_head_dim)

    # Flatten inputs as in forward pass
    positions_flat = positions.reshape(-1, position_dim)
    rope_freqs_flat = rope_freqs.reshape(position_dim, -1)

    # Gradient for matrix multiplication: If C = A @ B
    # Then grad_A = grad_C @ B^T and grad_B = A^T @ grad_C
    if needs_grad_positions:
        grad_positions_flat = torch.mm(grad_mm_result, rope_freqs_flat.t())
        grad_positions = grad_positions_flat.view(batch_dims + (position_dim,))
        grad_positions = grad_positions.sum_to_size(positions_in_shape)
    else:
        grad_positions = None

    if needs_grad_rope_freqs:
        grad_rope_freqs_flat = torch.mm(positions_flat.t(), grad_mm_result)
        grad_rope_freqs = grad_rope_freqs_flat.view(
            position_dim, n_freq_groups, n_heads, half_head_dim
        )
        grad_rope_freqs = grad_rope_freqs.sum_to_size(rope_freqs_in_shape)
    else:
        grad_rope_freqs = None

    return grad_positions, grad_rope_freqs
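
A quick way to sanity-check this kernel is to compare it against autograd applied to calculate_rope. A minimal sketch (import path assumed):

import torch
from nd_rotary_encodings.functional.forward_backward_fns import (  # assumed path
    calculate_rope,
    calculate_rope_backward,
)

positions = torch.randn(2, 8, 3, requires_grad=True)
rope_freqs = torch.randn(3, 1, 4, 16, requires_grad=True)

rope_encoding = calculate_rope(positions, rope_freqs)
grad_out = torch.randn_like(rope_encoding)

# Gradients computed by autograd through the forward kernel...
auto_grads = torch.autograd.grad(rope_encoding, (positions, rope_freqs), grad_out)
# ...should match the hand-written backward kernel
manual_grads = calculate_rope_backward(grad_out, positions, rope_freqs, True, True)

for auto_g, manual_g in zip(auto_grads, manual_grads):
    torch.testing.assert_close(auto_g, manual_g)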

rotate_embeddings(embeddings, rope_encoding, needs_autograd=True)

Applies rotary position encoding (RoPE) to the embeddings tensor via complex multiplication.

Parameters:
  • embeddings (Tensor) –

    Embeddings tensor to be rotated (usually a query or key tensor) of real dtype and shape [..., n_heads, head_dim]

  • rope_encoding (Tensor) –

    Position encoding of real dtype and shape [..., n_heads, head_dim/2] or [..., 1, head_dim/2] (broadcasted over heads)

  • needs_autograd (bool, default: True ) –

    If you need this function to be tracked by autograd, keep this at True. If False, additional autograd-incompatible memory optimizations are applied; the backward pass will then fail, so these optimizations are not applied by default for safety.

Returns:
  • embeddings_rotated( Tensor ) –

    Embedding tensor after rotation, of shape [..., n_heads, head_dim] and real dtype

Note
  • This function uses PyTorch's complex number operations, which only support single and double precision. If embeddings and rope_encoding are half precision or lower, they are temporarily upcast to float32, and the output is downcast back to the original dtype of embeddings before being returned.
Source code in nd_rotary_encodings/functional/forward_backward_fns.py
@torch.jit.script
def rotate_embeddings(
    embeddings: Tensor, rope_encoding: Tensor, needs_autograd: bool = True
) -> Tensor:
    """Applies rotary position encoding (RoPE) to the embeddings tensor via
    complex multiplication.

    Args:
        embeddings (Tensor): Embeddings tensor to be rotated (usually a query or
            key tensor) of real dtype and shape [..., n_heads, head_dim]
        rope_encoding (Tensor): Position encoding of real dtype and shape
            [..., n_heads, head_dim/2] or
            [..., 1,       head_dim/2] (broadcasted over heads)
        needs_autograd (bool): If you need this function to be tracked by autograd,
            keep this at True. If False, additional autograd-incompatible
            memory optimizations are applied. The function will fail in the backward
            pass if this option is False, so the optimizations are not applied by
            default for safety.

    Returns:
        - embeddings_rotated (Tensor): Embedding tensor after rotation, of shape
            [..., n_heads, head_dim] and real dtype

    Note:
        - This function uses Pytorch's complex number operations, which only support
            single and double precision. If `embeddings` and `rope_encoding` are
            half precision or lower, they are temporarily upcasted to float32 for
            this function, and the output is downcasted back to `embeddings`'s original
            dtype before returning it.
    """
    validate_same_ndims(embeddings, "embeddings", rope_encoding, "rope_encoding")
    validate_at_least_nd(embeddings, "embeddings", 3)
    validate_real(embeddings, "embeddings")
    validate_real(rope_encoding, "rope_encoding")
    validate_head_dim(embeddings, rope_encoding)

    # Save original dtype
    embeddings_dtype = embeddings.dtype

    # Upcast if needed
    embeddings = _upcast_if_needed(embeddings)
    rope_encoding = _upcast_if_needed(rope_encoding)

    # Convert to complex and apply rotation
    emb_complex_shape = embeddings.shape[:-1] + (embeddings.size(-1) // 2, 2)
    embeddings_complex = torch.view_as_complex(embeddings.reshape(emb_complex_shape))
    rope_encoding_complex = torch.polar(torch.ones_like(rope_encoding), rope_encoding)

    # multiply and convert back to real
    if needs_autograd:
        embeddings_rotated = embeddings_complex * rope_encoding_complex
    else:
        # can use an in-place op rather than creating a new tensor
        embeddings_rotated = embeddings_complex
        embeddings_rotated *= rope_encoding_complex
    embeddings_rotated = torch.view_as_real(embeddings_rotated).reshape_as(embeddings)

    # Cast back
    embeddings_rotated = embeddings_rotated.to(embeddings_dtype)

    return embeddings_rotated
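
For intuition, each entry of rope_encoding is a phase angle that rotates one (even, odd) feature pair of the embedding. A minimal check (import path assumed):

import math

import torch
from nd_rotary_encodings.functional.forward_backward_fns import rotate_embeddings  # assumed path

theta = 0.3
emb = torch.tensor([[[1.0, 0.0]]])    # [batch=1, n_heads=1, head_dim=2]
enc = torch.full((1, 1, 1), theta)    # [batch=1, n_heads=1, head_dim/2=1]

# The unit vector (1, 0) is rotated by theta radians
rotated = rotate_embeddings(emb, enc)
expected = torch.tensor([[[math.cos(theta), math.sin(theta)]]])
torch.testing.assert_close(rotated, expected)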

rotate_embeddings_backward(grad_embeddings_rotated, embeddings, rope_encoding, needs_grad_embeddings=True, needs_grad_rope_encoding=True, needs_autograd=True)

Performs the backward pass of applying rotary positional encoding (RoPE).

Computes gradients through complex number operations used in the RoPE forward pass. For complex multiplication z = x * y, the gradients are: dL/dx = dL/dz * conj(y) and dL/dy = dL/dz * conj(x).

Parameters:
  • grad_embeddings_rotated (Tensor) –

    Gradient of loss with respect to rotated embeddings, of shape [..., n_heads, head_dim]

  • embeddings (Tensor) –

    Original, un-rotated embeddings tensor of real dtype and shape [..., n_heads, head_dim].

  • rope_encoding (Tensor) –

    Real representation of positional encodings of real dtype and shape [..., n_heads, head_dim/2] or [..., 1, head_dim/2]

  • needs_grad_embeddings (bool, default: True ) –

    Whether gradients for embeddings are needed. Default: True

  • needs_grad_rope_encoding (bool, default: True ) –

    Whether gradients for positional encodings are needed. Default: True

  • needs_autograd (bool, default: True ) –

    If you need this function to be tracked by autograd, keep this at True. If False, additional autograd-incompatible memory optimizations are applied; the backward pass will then fail, so these optimizations are not applied by default for safety.

Returns:
  • grad_embeddings( Tensor ) –

    Gradient tensor for the unrotated embeddings, of shape [..., n_heads, head_dim] and real dtype, or None if not needed

  • grad_rope_encoding( Tensor ) –

    Gradient tensor for the positional encodings of real dtype and shape [..., n_heads, head_dim/2] or [..., 1, head_dim/2], or None if not needed

Note
  • This function uses PyTorch's complex number operations, which only support single and double precision. If any of the input tensors are half precision or lower, they are temporarily upcast to float32, and the output gradients are downcast back to the original dtypes of embeddings and rope_encoding, respectively, before being returned.
Source code in nd_rotary_encodings/functional/forward_backward_fns.py
@torch.jit.script
def rotate_embeddings_backward(
    grad_embeddings_rotated: Tensor,
    embeddings: Tensor,
    rope_encoding: Tensor,
    needs_grad_embeddings: bool = True,
    needs_grad_rope_encoding: bool = True,
    needs_autograd: bool = True,
) -> tuple[Optional[Tensor], Optional[Tensor]]:
    """Perform the backward pass of applying rotary positional encoding (RoPE)

    Computes gradients through complex number operations used in the RoPE
    forward pass. For complex multiplication z = x * y, the gradients are:
    dL/dx = dL/dz * conj(y) and dL/dy = dL/dz * conj(x).

    Args:
        grad_embeddings_rotated (Tensor): Gradient of loss with respect to rotated
            embeddings, of shape [..., n_heads, head_dim]
        embeddings (Tensor): Original, un-rotated embeddings tensor of real dtype and
            shape [..., n_heads, head_dim].
        rope_encoding (Tensor): Real representation of positional encodings
            of real dtype and shape
            [..., n_heads, head_dim/2] or
            [..., 1,       head_dim/2]
        needs_grad_embeddings (bool): Whether gradients for embeddings are needed.
            Default: True
        needs_grad_rope_encoding (bool): Whether gradients for positional encodings
            are needed. Default: True
        needs_autograd (bool): If you need this function to be tracked by autograd,
            keep this at True. If False, additional autograd-incompatible
            memory optimizations are applied. The function will fail in the backward
            pass if this option is False, so the optimizations are not applied by
            default for safety.

    Returns:
        grad_embeddings (Tensor): Gradient tensor for the unrotated embeddings,
            of shape [..., n_heads, head_dim] and real dtype,
            or None if not needed
        grad_rope_encoding (Tensor): Gradient tensor for the positional encodings
            of real dtype and shape
            [..., n_heads, head_dim/2] or
            [..., 1,       head_dim/2], or None if not needed

    Note:
        - This function uses Pytorch's complex number operations, which only support
            single and double precision. If any of the input tensors are
            half precision or lower, they are temporarily upcasted to float32 for
            this function, and the output gradients are downcasted back to the original
            dtype of `embeddings` and `rope_encoding`, respectively, before returning
            them.
    """
    validate_same_ndims(embeddings, "embeddings", rope_encoding, "rope_encoding")
    validate_real(grad_embeddings_rotated, "grad_embeddings_rotated")
    validate_real(embeddings, "embeddings")
    validate_real(rope_encoding, "rope_encoding")

    if grad_embeddings_rotated.shape != embeddings.shape:
        raise ValueError(
            "Expected grad_embeddings_rotated and embeddings to have the same shape, "
            f"got {grad_embeddings_rotated.shape} and {embeddings.shape}"
        )
    validate_head_dim(embeddings, rope_encoding)

    # Check for no grads needed
    if not needs_grad_embeddings and not needs_grad_rope_encoding:
        # Early return
        return None, None

    # Save input dtypes
    embeddings_dtype = embeddings.dtype
    rope_encoding_dtype = rope_encoding.dtype

    # Upcast if needed
    grad_embeddings_rotated = _upcast_if_needed(grad_embeddings_rotated)
    embeddings = _upcast_if_needed(embeddings)
    rope_encoding = _upcast_if_needed(rope_encoding)

    # Convert grad_tensor_rotated to complex
    to_complex_shape = grad_embeddings_rotated.shape[:-1] + (
        grad_embeddings_rotated.size(-1) // 2,
        2,
    )
    grad_embeddings_rotated_complex = torch.view_as_complex(
        grad_embeddings_rotated.reshape(to_complex_shape).contiguous()
    )

    # Complex multiplication gradient
    # For z = x * y, we have dL/dx = dL/dz * conj(y) and dL/dy = dL/dz * conj(x)

    # Unconditionally recompute complex version of rope_encoding tensor since it's
    # required by both output grads
    rope_encoding_complex = torch.polar(torch.ones_like(rope_encoding), rope_encoding)

    # Gradient for embeddings tensor
    if needs_grad_embeddings:
        if needs_autograd or needs_grad_rope_encoding:
            grad_emb_complex = (
                grad_embeddings_rotated_complex * rope_encoding_complex.conj()
            )
        else:
            # Can modify tensor in-place rather than creating a new one
            # Need to check needs_grad_rope_encoding because we'll need
            # grad_embeddings_rotated_complex in that branch
            grad_emb_complex = grad_embeddings_rotated_complex
            grad_emb_complex *= rope_encoding_complex.conj()
        grad_embeddings = torch.view_as_real(grad_emb_complex).reshape_as(
            grad_embeddings_rotated
        )
        grad_embeddings = grad_embeddings.to(embeddings_dtype)  # downcast
    else:
        grad_embeddings = None

    # Gradient for position encoding
    if needs_grad_rope_encoding:
        # Recompute complex version of embeddings tensor
        emb_complex_shape = embeddings.shape[:-1] + (embeddings.size(-1) // 2, 2)
        embeddings_complex = torch.view_as_complex(
            embeddings.reshape(emb_complex_shape)
        )

        # Compute gradient with respect to rope_encoding_complex
        if needs_autograd:
            grad_rope_encoding_complex = (
                grad_embeddings_rotated_complex * embeddings_complex.conj()
            )
        else:
            # Can modify tensor in-place rather than creating a new one
            grad_rope_encoding_complex = grad_embeddings_rotated_complex
            grad_rope_encoding_complex *= embeddings_complex.conj()

        # Check if broadcasting happened
        is_broadcasted = (
            rope_encoding_complex.size(-2) == 1 and embeddings_complex.size(-2) > 1
        )

        if is_broadcasted:
            # Sum gradients across broadcasted dimension (heads)
            grad_rope_encoding_complex = grad_rope_encoding_complex.sum(
                dim=-2, keepdim=True
            )

        # Then compute gradient with respect to rope_encoding (the phase angle)
        # Since rope_encoding_complex = exp(i*rope_encoding), the gradient is:
        # dL/d(rope_encoding)
        #   = Im(dL/d(rope_encoding_complex) / rope_encoding_complex)
        if needs_autograd:
            grad_rope_encoding = (
                grad_rope_encoding_complex / rope_encoding_complex
            ).imag
        else:
            # Can modify tensor in-place rather than creating a new one
            grad_rope_encoding = grad_rope_encoding_complex
            grad_rope_encoding /= rope_encoding_complex
            grad_rope_encoding = grad_rope_encoding.imag

        grad_rope_encoding = grad_rope_encoding.to(rope_encoding_dtype)  # downcast
    else:
        grad_rope_encoding = None

    return grad_embeddings, grad_rope_encoding
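
As with calculate_rope_backward, this kernel can be sanity-checked against autograd run through rotate_embeddings. A minimal sketch (import path assumed):

import torch
from nd_rotary_encodings.functional.forward_backward_fns import (  # assumed path
    rotate_embeddings,
    rotate_embeddings_backward,
)

emb = torch.randn(2, 6, 4, 32, requires_grad=True)
enc = torch.randn(2, 6, 4, 16, requires_grad=True)

rotated = rotate_embeddings(emb, enc)
grad_out = torch.randn_like(rotated)

# Autograd through the forward kernel vs. the hand-written backward kernel
auto_grads = torch.autograd.grad(rotated, (emb, enc), grad_out)
manual_grads = rotate_embeddings_backward(grad_out, emb, enc)

for auto_g, manual_g in zip(auto_grads, manual_grads):
    torch.testing.assert_close(auto_g, manual_g)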

Checkpointed wrappers

calculate_rope_checkpointed(positions, rope_freqs)

Memory-efficient differentiable version of calculate_rope.

This wrapper avoids building a full autograd graph inside the scripted kernel and stores only the positions and rope_freqs tensors required by the backward formula. This results in potentially large memory savings if the sequence length is large.

Parameters:
  • positions (Tensor) –

    Position information for each embedding element of shape [..., position_dim], where ... are arbitrary batch dimensions and position_dim is the dimensionality of the position representation.

  • rope_freqs (Tensor) –

    Frequency values for rotary encodings of shape [position_dim, n_freq_groups, n_heads, head_dim/2], where n_freq_groups and n_heads can be 1 for broadcasting.

Returns:
  • Tensor –

    Computed positional encoding of shape [..., n_heads, head_dim/2]

Source code in nd_rotary_encodings/functional/autograd.py
def calculate_rope_checkpointed(positions: Tensor, rope_freqs: Tensor) -> Tensor:
    """Memory-efficient differentiable version of ``calculate_rope``.

    This wrapper avoids building a full autograd graph inside the scripted
    kernel and stores only the ``positions`` and ``rope_freqs`` tensors required
    by the backward formula. This results in potentially large memory savings
    if the sequence length is large.

    Args:
        positions (Tensor): Position information for each embedding element of shape
            [..., position_dim], where ... are arbitrary batch dimensions and
            position_dim is the dimensionality of the position representation.
        rope_freqs (Tensor): Frequency values for rotary encodings of shape
            [position_dim, n_freq_groups, n_heads, head_dim/2], where n_freq_groups
            and n_heads can be 1 for broadcasting.

    Returns:
        Tensor: Computed positional encoding of shape
            [..., n_heads, head_dim/2]
    """
    out = CalculateRopeFunction.apply(positions, rope_freqs)
    return out  # pyright: ignore[reportReturnType]
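
A short sketch of the drop-in relationship with calculate_rope (import paths assumed):

import torch
from nd_rotary_encodings.functional.autograd import calculate_rope_checkpointed  # assumed path
from nd_rotary_encodings.functional.forward_backward_fns import calculate_rope   # assumed path

positions = torch.randn(2, 128, 2)
rope_freqs = torch.randn(2, 1, 8, 32)

# Identical forward values; the checkpointed version saves only positions and
# rope_freqs for the backward pass instead of the intermediate autograd graph.
torch.testing.assert_close(
    calculate_rope_checkpointed(positions, rope_freqs),
    calculate_rope(positions, rope_freqs),
)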

rotate_embeddings_checkpointed(embeddings, rope_encoding)

Memory-efficient differentiable version of rotate_embeddings.

Wrapper for a custom autograd Function that itself wraps rotate_embeddings for forward pass and rotate_embeddings_backward for the backward pass. This setup allows for custom checkpointing that avoids storage of intermediate tensors.

Parameters:
  • embeddings (Tensor) –

    Embeddings tensor to be rotated (usually a query or key tensor) of real dtype and shape [..., n_heads, head_dim]

  • rope_encoding (Tensor) –

    Position encoding of real dtype and shape [..., n_heads, head_dim/2] or [..., 1, head_dim/2] (broadcasted over heads)

Returns:
  • embeddings_rotated( Tensor ) –

    Embedding tensor after rotation, of shape [..., n_heads, head_dim] and real dtype

Source code in nd_rotary_encodings/functional/autograd.py
def rotate_embeddings_checkpointed(embeddings: Tensor, rope_encoding: Tensor) -> Tensor:
    """Memory-efficient differentiable version of ``rotate_embeddings``.

    Wrapper for a custom autograd Function that itself wraps `rotate_embeddings` for
    forward pass and `rotate_embeddings_backward` for the backward pass. This setup
    allows for custom checkpointing that avoids storage of intermediate tensors.

    Args:
        embeddings (Tensor): Embeddings tensor to be rotated (usually a query or
            key tensor) of real dtype and shape [..., n_heads, head_dim]
        rope_encoding (Tensor): Position encoding of real dtype and shape
            [..., n_heads, head_dim/2] or
            [..., 1,       head_dim/2] (broadcasted over heads)

    Returns:
        embeddings_rotated (Tensor): Embedding tensor after rotation, of shape
            [..., n_heads, head_dim] and real dtype
    """
    out = RotateEmbeddingsFunction.apply(embeddings, rope_encoding)
    return out  # pyright: ignore[reportReturnType]
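
A minimal end-to-end sketch chaining both checkpointed wrappers (import path assumed); gradients reach all three inputs without the full intermediate autograd graph being stored:

import torch
from nd_rotary_encodings.functional.autograd import (  # assumed path
    calculate_rope_checkpointed,
    rotate_embeddings_checkpointed,
)

embeddings = torch.randn(2, 128, 8, 64, requires_grad=True)
positions = torch.randn(2, 128, 2, requires_grad=True)
rope_freqs = torch.randn(2, 1, 8, 32, requires_grad=True)

rope_encoding = calculate_rope_checkpointed(positions, rope_freqs)
rotated = rotate_embeddings_checkpointed(embeddings, rope_encoding)
rotated.sum().backward()  # populates .grad on embeddings, positions, and rope_freqs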

apply_rope_checkpointed(embeddings, positions, rope_freqs, key_embeddings=None)

End-to-end rotary positional encoding (RoPE) in a single autograd node.

Internally, this function computes the full RoPE encoding tensor and applies it to the embeddings without storing it for backprop. Since this tensor can be very large for long sequences and/or large embedding dimensions but is cheap to calculate, this gradient checkpointing trades a small amount of recomputation for potentially significant memory savings.

Parameters:
  • embeddings (Tensor) –

    Embeddings tensor to be rotated (usually a query or key tensor) of real dtype and shape [..., n_heads, head_dim]

  • positions (Tensor) –

    Position information for each embedding element of shape [..., position_dim], where ... are arbitrary batch dimensions and position_dim is the dimensionality of the position representation.

  • rope_freqs (Tensor) –

    Frequency values for rotary encodings of shape [position_dim, n_freq_groups, n_heads, head_dim/2], where n_freq_groups and n_heads can be 1 for broadcasting.

  • key_embeddings (Optional[Tensor], default: None ) –

    Optional second embeddings tensor (e.g., the key tensor in self-attention) of the same shape as embeddings, to be rotated alongside it. If provided, the function returns a tuple of both rotated tensors.

Returns:
  • embeddings_rotated( Tensor ) –

    Embedding tensor after rotation, of shape [..., n_heads, head_dim] and same dtype as embeddings. If key_embeddings is provided, a tuple of (embeddings_rotated, key_embeddings_rotated) is returned instead.

Source code in nd_rotary_encodings/functional/autograd.py
def apply_rope_checkpointed(
    embeddings: Tensor,
    positions: Tensor,
    rope_freqs: Tensor,
    key_embeddings: Optional[Tensor] = None,
) -> Union[Tensor, tuple[Tensor, Tensor]]:
    """End-to-end rotary positional encoding (RoPE) in a single autograd node.

    Internally, this function computes the full RoPE encoding tensor and applies it
    to the embeddings, without storing it for backprop. Since this tensor is potentially
    very large for large sequence length and/or embedding dimension but is cheap to
    calculate, this gradient checkpointing logic can trade off potentially significant
    memory savings for a small computation increase.

    Args:
        embeddings (Tensor): Embeddings tensor to be rotated (usually a query or
            key tensor) of real dtype and shape [..., n_heads, head_dim]
        positions (Tensor): Position information for each embedding element of shape
            [..., position_dim], where ... are arbitrary batch dimensions and
            position_dim is the dimensionality of the position representation.
        rope_freqs (Tensor): Frequency values for rotary encodings of shape
            [position_dim, n_freq_groups, n_heads, head_dim/2], where n_freq_groups
            and n_heads can be 1 for broadcasting.
        key_embeddings (Optional[Tensor]): Optional second embeddings tensor (e.g.,
            the key tensor in self-attention) of the same shape as `embeddings`, to
            be rotated alongside it. Default: None.

    Returns:
        embeddings_rotated (Tensor): Embedding tensor after rotation, of shape
            [..., n_heads, head_dim] and same dtype as `embeddings`. If
            `key_embeddings` is provided, a tuple of both rotated tensors is
            returned instead.
    """
    out = ApplyRoPEFunction.apply(embeddings, positions, rope_freqs, key_embeddings)
    assert out is not None
    if key_embeddings is None:
        return out[0]
    return out  # pyright: ignore[reportReturnType]
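
A usage sketch, including the optional key_embeddings argument visible in the signature (when it is given, both tensors are rotated and returned as a tuple; import path assumed):

import torch
from nd_rotary_encodings.functional.autograd import apply_rope_checkpointed  # assumed path

queries = torch.randn(2, 128, 8, 64, requires_grad=True)
keys = torch.randn(2, 128, 8, 64, requires_grad=True)
positions = torch.randn(2, 128, 2)
rope_freqs = torch.randn(2, 1, 8, 32)

# Single tensor in, single tensor out
q_rotated = apply_rope_checkpointed(queries, positions, rope_freqs)

# With key_embeddings supplied, a (query, key) pair of rotated tensors is returned
q_rotated, k_rotated = apply_rope_checkpointed(queries, positions, rope_freqs, keys)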

Forward-only wrappers

These wrappers enable additional small optimizations that are only applicable when gradients are not needed, i.e., in forward (inference) mode.

rotate_embeddings_forward_only(embeddings, rope_encoding, inplace=False)

Forward-only version of rotate_embeddings.

This calls rotate_embeddings with additional optimizations like in-place tensor operations that make it incompatible with autograd.

Parameters:
  • embeddings (Tensor) –

    Embeddings tensor to be rotated (usually a query or key tensor) of real dtype and shape [..., n_heads, head_dim]

  • rope_encoding (Tensor) –

    Position encoding of real dtype and shape [..., n_heads, head_dim/2] or [..., 1, head_dim/2] (broadcasted over heads)

  • inplace (bool, default: False ) –

    If True, the supplied embeddings tensor is rotated in-place (i.e., overwritten with the new values) for maximum memory efficiency. Default: False.

Returns:
  • embeddings_rotated( Tensor ) –

    Embedding tensor after rotation, of shape [..., n_heads, head_dim] and real dtype

Source code in nd_rotary_encodings/functional/forward_only.py
@torch.no_grad()
def rotate_embeddings_forward_only(
    embeddings: Tensor, rope_encoding: Tensor, inplace: bool = False
) -> Tensor:
    """Forward-only version of ``rotate_embeddings``.

    This calls `rotate_embeddings` with additional optimizations like in-place tensor
    operations that make it incompatible with autograd.

    Args:
        embeddings (Tensor): Embeddings tensor to be rotated (usually a query or
            key tensor) of real dtype and shape [..., n_heads, head_dim]
        rope_encoding (Tensor): Position encoding of real dtype and shape
            [..., n_heads, head_dim/2] or
            [..., 1,       head_dim/2] (broadcasted over heads)
        inplace (bool): If True, the supplied `embeddings` tensor is rotated in-place
            (i.e., overwritten with the new values) for maximum memory efficiency.
            Default: False.

    Returns:
        embeddings_rotated (Tensor): Embedding tensor after rotation, of shape
            [..., n_heads, head_dim] and real dtype
    """
    if not inplace:
        embeddings = embeddings.clone()
    return rotate_embeddings(embeddings, rope_encoding, needs_autograd=False)
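
A small sketch of the inplace semantics (import path assumed): with inplace=True the supplied embeddings buffer is overwritten with the rotated values, and the returned tensor holds the same data.

import torch
from nd_rotary_encodings.functional.forward_only import rotate_embeddings_forward_only  # assumed path

emb = torch.randn(2, 8, 4, 32)
enc = torch.randn(2, 8, 4, 16)
original = emb.clone()

out = rotate_embeddings_forward_only(emb, enc, inplace=True)

assert not torch.equal(emb, original)   # the input buffer was overwritten
torch.testing.assert_close(out, emb)    # the return value matches the rotated input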

apply_rope_forward_only(embeddings, positions, rope_freqs, inplace=False, self_attn_key_embeddings=None)

End-to-end rotary positional encoding (RoPE) without autograd tracking.

This calls calculate_rope and rotate_embeddings with additional optimizations like in-place tensor operations that make it incompatible with autograd.

Parameters:
  • embeddings (Tensor) –

    Embeddings tensor to be rotated (usually a query or key tensor) of real dtype and shape [..., n_heads, head_dim]

  • positions (Tensor) –

    Position information for each embedding element of shape [..., position_dim], where ... are arbitrary batch dimensions and position_dim is the dimensionality of the position representation.

  • rope_freqs (Tensor) –

    Frequency values for rotary encodings of shape [position_dim, n_freq_groups, n_heads, head_dim/2], where n_freq_groups and n_heads can be 1 for broadcasting.

  • inplace (bool, default: False ) –

    If True, the supplied embeddings tensor is rotated in-place (i.e., overwritten with the new values) for maximum memory efficiency. Default: False.

  • self_attn_key_embeddings (Optional[Tensor], default: None ) –

    Optional second embeddings tensor (e.g., the key tensor in self-attention) of the same shape as embeddings, rotated with the same positional encoding. If provided, the function returns a tuple of both rotated tensors.

Returns:
  • embeddings_rotated( Tensor ) –

    Embedding tensor after rotation, of shape [..., n_heads, head_dim] and same dtype as embeddings. If self_attn_key_embeddings is provided, a tuple of (embeddings_rotated, key_embeddings_rotated) is returned instead.

Source code in nd_rotary_encodings/functional/forward_only.py
@torch.no_grad()
def apply_rope_forward_only(
    embeddings: Tensor,
    positions: Tensor,
    rope_freqs: Tensor,
    inplace: bool = False,
    self_attn_key_embeddings: Optional[Tensor] = None,
) -> Union[Tensor, tuple[Tensor, Tensor]]:
    """End-to-end rotary positional encoding (RoPE) in a single autograd node.

    This calls `calculate_rope` and `rotate_embeddings` with additional optimizations
    like in-place tensor operations that make it incompatible with autograd.

    Args:
        embeddings (Tensor): Embeddings tensor to be rotated (usually a query or
            key tensor) of real dtype and shape [..., n_heads, head_dim]
        positions (Tensor): Position information for each embedding element of shape
            [..., position_dim], where ... are arbitrary batch dimensions and
            position_dim is the dimensionality of the position representation.
        rope_freqs (Tensor): Frequency values for rotary encodings of shape
            [position_dim, n_freq_groups, n_heads, head_dim/2], where n_freq_groups
            and n_heads can be 1 for broadcasting.
        inplace (bool): If True, the supplied `embeddings` tensor is rotated in-place
            (i.e., overwritten with the new values) for maximum memory efficiency.
            Default: False.
        self_attn_key_embeddings (Optional[Tensor]): Optional second embeddings
            tensor (e.g., the key tensor in self-attention) of the same shape as
            `embeddings`, rotated with the same positional encoding. Default: None.

    Returns:
        embeddings_rotated (Tensor): Embedding tensor after rotation, of shape
            [..., n_heads, head_dim] and same dtype as `embeddings`. If
            `self_attn_key_embeddings` is provided, a tuple of both rotated tensors
            is returned instead.
    """
    rope_encoding = calculate_rope(positions, rope_freqs)
    embeddings_rotated = rotate_embeddings_forward_only(
        embeddings, rope_encoding, inplace=inplace
    )

    if self_attn_key_embeddings is None:
        return embeddings_rotated

    key_embeddings_rotated = rotate_embeddings_forward_only(
        self_attn_key_embeddings, rope_encoding, inplace=inplace
    )
    return embeddings_rotated, key_embeddings_rotated
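
A usage sketch for inference, including the optional self_attn_key_embeddings argument from the signature (import path assumed):

import torch
from nd_rotary_encodings.functional.forward_only import apply_rope_forward_only  # assumed path

queries = torch.randn(2, 128, 8, 64)
keys = torch.randn(2, 128, 8, 64)
positions = torch.randn(2, 128, 2)
rope_freqs = torch.randn(2, 1, 8, 32)

# Rotates queries and keys with the same encoding; with inplace=True both
# input buffers are overwritten with the rotated values.
q_rot, k_rot = apply_rope_forward_only(
    queries, positions, rope_freqs, inplace=True, self_attn_key_embeddings=keys
)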