diff --git a/python/mlx/nn/layers/positional_encoding.py b/python/mlx/nn/layers/positional_encoding.py
index 6c363e368..032fd0f92 100644
--- a/python/mlx/nn/layers/positional_encoding.py
+++ b/python/mlx/nn/layers/positional_encoding.py
@@ -24,13 +24,21 @@ class RoPE(Module):
             implementation which is slightly less efficient. Default: ``False``.
         base (float, optional): The base used to compute angular frequency for
             each dimension in the positional encodings. Default: ``10000``.
+        scale (float, optional): The scale used to scale the positions. Default: ``1.0``.
     """
 
-    def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
+    def __init__(
+        self,
+        dims: int,
+        traditional: bool = False,
+        base: float = 10000,
+        scale: float = 1.0,
+    ):
         super().__init__()
         self.dims = dims
         self.traditional = traditional
         self.base = base
+        self.scale = scale
 
     def _extra_repr(self):
         return f"{self.dims}, traditional={self.traditional}"
@@ -68,7 +76,7 @@ class RoPE(Module):
         x = mx.reshape(x, (-1, shape[-2], shape[-1]))
         N = x.shape[1] + offset
         costheta, sintheta = RoPE.create_cos_sin_theta(
-            N, self.dims, offset=offset, base=self.base, dtype=x.dtype
+            N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype
         )
 
         rope = (
@@ -80,10 +88,15 @@ class RoPE(Module):
 
     @staticmethod
     def create_cos_sin_theta(
-        N: int, D: int, offset: int = 0, base: float = 10000, dtype=mx.float32
+        N: int,
+        D: int,
+        offset: int = 0,
+        base: float = 10000,
+        scale: float = 1.0,
+        dtype=mx.float32,
     ):
         D = D // 2
-        positions = mx.arange(offset, N, dtype=dtype)
+        positions = mx.arange(offset, N, dtype=dtype) * scale
         freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
         theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
         return mx.cos(theta), mx.sin(theta)
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 204145c01..6557e7dbe 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -749,7 +749,7 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertEqual(y.dtype, mx.float32)
 
     def test_rope(self):
-        for kwargs in [{}, {"traditional": False}, {"base": 10000}]:
+        for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
            rope = nn.RoPE(4, **kwargs)
            shape = (1, 3, 4)
            x = mx.random.uniform(shape=shape)
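
A minimal usage sketch of the new scale argument, applying the module the same way the existing tests do (the 0.25 value and the (1, 3, 4) input shape come from test_rope; everything else is illustrative, not part of the patch):

import mlx.core as mx
import mlx.nn as nn

# Scaled RoPE over 4 feature dimensions. Per the patch, positions are
# multiplied by scale before the angles are computed, so scale=1.0 keeps
# the previous behavior and scale < 1 compresses the effective positions.
rope = nn.RoPE(4, scale=0.25)

x = mx.random.uniform(shape=(1, 3, 4))
y = rope(x)  # output has the same shape as the input: (1, 3, 4)

# The static helper added a matching scale keyword; it can be called
# directly to inspect the cos/sin tables for N=3 positions and D=4 dims.
costheta, sintheta = nn.RoPE.create_cos_sin_theta(3, 4, offset=0, base=10000, scale=0.25)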