mlx.nn.RoPE
- class mlx.nn.RoPE(dims: int, traditional: bool = False, base: float = 10000)
Implements the rotary positional encoding [1].
The traditional implementation rotates consecutive pairs of elements in the feature dimension, while the default implementation, for efficiency, rotates pairs of elements that are half the feature dimension apart.
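The difference between the two pairings can be illustrated with a minimal pure-Python sketch. The helper names `rope_pairs` and `rotate_pair` are hypothetical and are not part of MLX; the sketch only shows which feature indices are rotated together and how a single pair is rotated by an angle, assuming the standard 2D rotation used in [1].

```python
import math

def rope_pairs(dims, traditional=False):
    """Return the index pairs that are rotated together (illustrative helper, not MLX API)."""
    half = dims // 2
    if traditional:
        # Consecutive elements: (0, 1), (2, 3), ...
        return [(2 * i, 2 * i + 1) for i in range(half)]
    # Elements half the rotated dimensions apart: (0, half), (1, half + 1), ...
    return [(i, i + half) for i in range(half)]

def rotate_pair(a, b, theta):
    """Rotate the 2D point (a, b) counterclockwise by the angle theta."""
    c, s = math.cos(theta), math.sin(theta)
    return a * c - b * s, a * s + b * c

print(rope_pairs(8, traditional=True))  # [(0, 1), (2, 3), (4, 5), (6, 7)]
print(rope_pairs(8))                    # [(0, 4), (1, 5), (2, 6), (3, 7)]
```

Both layouts apply the same kind of rotation to each pair; they differ only in which elements form a pair.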
[1]: https://arxiv.org/abs/2104.09864
- Parameters:
dims (int) – The number of feature dimensions to be rotated. If the input feature dimension is larger than dims, the remaining features are left unchanged.
traditional (bool, optional) – If set to True, use the traditional implementation, which is slightly less efficient. Default: False
base (float, optional) – The base used to compute the angular frequency for each dimension in the positional encodings. Default: 10000
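As a rough illustration of how base sets the angular frequencies, the sketch below assumes the standard formulation from [1], in which rotated pair i has frequency base^(-2i/dims) and the rotation angle is the token position times that frequency. The function names `rope_frequencies` and `rope_angles` are hypothetical, and MLX's internal computation may differ in detail.

```python
def rope_frequencies(dims, base=10000.0):
    """Angular frequency of each of the dims // 2 rotated pairs (assumed standard form)."""
    half = dims // 2
    # Pair i gets base ** (-2 * i / dims), written here as base ** (-i / half).
    return [base ** (-i / half) for i in range(half)]

def rope_angles(position, dims, base=10000.0):
    """Rotation angle applied to each pair at the given token position."""
    return [position * freq for freq in rope_frequencies(dims, base)]

# With dims=4 the two pairs have frequencies of about 1.0 and 0.01.
print(rope_frequencies(4))
print(rope_angles(3, 4))
```

Under this assumption, a larger base lowers the frequencies of the later pairs, so their angles change more slowly with position.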