mlx.nn.RoPE

class mlx.nn.RoPE(dims: int, traditional: bool = False, base: float = 10000)

Implements the rotary positional encoding [1].

The traditional implementation rotates consecutive pairs of elements in the feature dimension, while the default implementation rotates pairs of elements separated by a stride of half the feature dimensions, which is more efficient.

[1]: https://arxiv.org/abs/2104.09864
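The two pairing layouts described above can be illustrated with a small pure-Python sketch. The helper `rope_pairs` is hypothetical and not part of MLX; it only lists which feature indices would be rotated together, assuming an even `dims`.

```python
def rope_pairs(dims, traditional=False):
    """Return the index pairs rotated together (hypothetical helper, not MLX API)."""
    half = dims // 2
    if traditional:
        # Consecutive pairs: (0, 1), (2, 3), ...
        return [(2 * i, 2 * i + 1) for i in range(half)]
    # Pairs separated by a stride of dims / 2: (0, half), (1, half + 1), ...
    return [(i, i + half) for i in range(half)]


print(rope_pairs(8, traditional=True))  # [(0, 1), (2, 3), (4, 5), (6, 7)]
print(rope_pairs(8))                    # [(0, 4), (1, 5), (2, 6), (3, 7)]
```

With the strided layout, the two halves of the feature vector can be handled as contiguous slices, which is one plausible reason it is the cheaper default.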

Parameters:
  • dims (int) – The number of feature dimensions to rotate. If the input's feature dimension is larger than dims, the remaining features are left unchanged.

  • traditional (bool, optional) – If True, use the traditional implementation, which is slightly less efficient. Default: False

  • base (float, optional) – The base used to compute angular frequency for each dimension in the positional encodings. Default: 10000
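To make the roles of the three parameters concrete, here is a minimal pure-Python reference sketch for a single feature vector at one position. It is not the library's implementation: the function name `rope_reference` and the `pos` argument are hypothetical, `dims` is assumed to be even, and the frequency schedule `base ** (-2 * i / dims)` is assumed from [1] rather than taken from the MLX source.

```python
import math


def rope_reference(x, pos, dims, traditional=False, base=10000.0):
    """Rotate the first `dims` features of `x` for position `pos` (illustrative only)."""
    half = dims // 2
    out = list(x)  # features at index >= dims are copied through unchanged
    for i in range(half):
        # Assumed schedule from [1]: pair i rotates at frequency base^(-2i/dims).
        theta = pos * base ** (-2.0 * i / dims)
        c, s = math.cos(theta), math.sin(theta)
        # Choose which two features form the rotated pair.
        a, b = (2 * i, 2 * i + 1) if traditional else (i, i + half)
        out[a] = x[a] * c - x[b] * s
        out[b] = x[a] * s + x[b] * c
    return out


x = [1.0, 0.0, 0.0, 0.0, 5.0, 6.0]
y = rope_reference(x, pos=1, dims=4)
print(y[4:])  # the last two features lie outside `dims` and are untouched
```

Each pair undergoes a plain 2D rotation, so the norm of the rotated part is preserved, and a larger `base` makes the higher-index pairs rotate more slowly with position.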