mlx.nn.RoPE

class mlx.nn.RoPE(dims: int, traditional: bool = False)

Implements the rotary positional encoding [1].

The traditional implementation rotates consecutive pairs of elements in the feature dimension. The default implementation instead pairs each element with the one half the feature dimensions away (element i with element i + dims/2), which is more efficient.
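
The difference between the two layouts can be sketched in plain Python. This is an illustration only, not the library's implementation: the helper names rope_pairs and rotate are hypothetical, and the frequency base of 10000 is an assumption taken from the paper in [1] rather than from this page.

```python
import math

def rope_pairs(dims, traditional=False):
    """Return the index pairs that are rotated together (hypothetical helper)."""
    half = dims // 2
    if traditional:
        # Consecutive elements: (0, 1), (2, 3), ...
        return [(2 * i, 2 * i + 1) for i in range(half)]
    # Stride of half the feature dimensions: (0, half), (1, half + 1), ...
    return [(i, i + half) for i in range(half)]

def rotate(x, position, traditional=False, base=10000.0):
    """Rotate one feature vector for one position (base is an assumption)."""
    dims = len(x)
    out = list(x)
    for i, (a, b) in enumerate(rope_pairs(dims, traditional)):
        # Each pair gets its own angle, scaled by the token position.
        theta = position * base ** (-2.0 * i / dims)
        c, s = math.cos(theta), math.sin(theta)
        out[a] = x[a] * c - x[b] * s
        out[b] = x[a] * s + x[b] * c
    return out

print(rope_pairs(4))                    # default layout
print(rope_pairs(4, traditional=True))  # traditional layout
```

Both layouts apply the same set of 2D rotations and differ only in which elements form a pair, so the vector norm is preserved either way.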

[1]: https://arxiv.org/abs/2104.09864

Parameters:
  • dims (int) – The number of feature dimensions to rotate. If the input has more than dims features, the remaining features are left unchanged.

  • traditional (bool) – If True, use the traditional implementation, which is slightly less efficient.
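
The effect of dims on a wider input can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's code: the function name rope_partial is hypothetical, it assumes the rotated features are the first dims entries, it uses the default (non-traditional) pairing, and the base of 10000 comes from the paper in [1].

```python
import math

def rope_partial(x, position, dims, base=10000.0):
    """Rotate the first `dims` features of x and pass the rest through.

    Hypothetical helper; assumes the leading `dims` features are the
    ones rotated and that base=10000 as in the RoPE paper.
    """
    half = dims // 2
    out = list(x)  # features at index >= dims are copied unchanged
    for i in range(half):
        theta = position * base ** (-2.0 * i / dims)
        c, s = math.cos(theta), math.sin(theta)
        a, b = i, i + half  # default pairing: stride of dims / 2
        out[a] = x[a] * c - x[b] * s
        out[b] = x[a] * s + x[b] * c
    return out

x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
y = rope_partial(x, position=3, dims=4)
print(y[4:])  # the last two features are untouched
```

Here the input has six features but only four are rotated, so the last two come out exactly as they went in.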