diff --git a/python/mlx/nn/layers/positional_encoding.py b/python/mlx/nn/layers/positional_encoding.py index 4852df546..32d5e2c88 100644 --- a/python/mlx/nn/layers/positional_encoding.py +++ b/python/mlx/nn/layers/positional_encoding.py @@ -23,6 +23,7 @@ class RoPE(Module): implementation which is slightly less efficient. Default: ``False`` base (float, optional): The base used to compute angular frequency for each dimension in the positional encodings. Default: ``10000`` + scale (float, optional): The scale used to scale the positions. Default: ``1.0`` """ def __init__(