diff --git a/mlx/array.h b/mlx/array.h index 6f26c3314..3801fa1d0 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -154,8 +154,8 @@ class array { }; private: - int idx; const array& arr; + int idx; }; ArrayIterator begin() const { diff --git a/python/mlx/nn/layers/positional_encoding.py b/python/mlx/nn/layers/positional_encoding.py index 5c3171fde..b121a5436 100644 --- a/python/mlx/nn/layers/positional_encoding.py +++ b/python/mlx/nn/layers/positional_encoding.py @@ -18,15 +18,18 @@ class RoPE(Module): Args: dims (int): The feature dimensions to be rotated. If the input feature - is larger than dims then the rest is left unchanged. - traditional (bool): If set to True choose the traditional - implementation which is slightly less efficient. + is larger than dims then the rest is left unchanged. + traditional (bool, optional): If set to True choose the traditional + implementation which is slightly less efficient. Default: ``False`` + base (float, optional): The base used to compute angular frequency for + each dimension in the positional encodings. Default: ``10000`` """ - def __init__(self, dims: int, traditional: bool = False): + def __init__(self, dims: int, traditional: bool = False, base: float = 10000): super().__init__() self.dims = dims self.traditional = traditional + self.base = base def _extra_repr(self): return f"{self.dims}, traditional={self.traditional}" @@ -64,7 +67,7 @@ class RoPE(Module): x = mx.reshape(x, (-1, shape[-2], shape[-1])) N = x.shape[1] + offset costheta, sintheta = RoPE.create_cos_sin_theta( - N, self.dims, offset=offset, dtype=x.dtype + N, self.dims, offset=offset, base=self.base, dtype=x.dtype ) rope = ( @@ -82,10 +85,7 @@ class RoPE(Module): positions = mx.arange(offset, N, dtype=dtype) freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D)) theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) - costheta = mx.cos(theta) - sintheta = mx.sin(theta) - - return costheta, sintheta + return mx.cos(theta), mx.sin(theta) class SinusoidalPositionalEncoding(Module): diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index f5597474d..259795654 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -463,6 +463,21 @@ class TestNN(mlx_tests.MLXTestCase): mx.array([0.8651, -0.3034, 0.0000, 0.3752]), ) + def test_rope(self): + for kwargs in [{}, {"traditional": False}, {"base": 10000}]: + rope = nn.RoPE(4, **kwargs) + shape = (1, 3, 4) + x = mx.random.uniform(shape=shape) + y = rope(x) + self.assertTrue(y.shape, shape) + self.assertTrue(y.dtype, mx.float32) + + y = rope(x, offset=3) + self.assertTrue(y.shape, shape) + + y = rope(x.astype(mx.float16)) + self.assertTrue(y.dtype, mx.float16) + if __name__ == "__main__": unittest.main()