diff --git a/mlx/fast.cpp b/mlx/fast.cpp index b51c747d8..378bb22ce 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -416,23 +416,25 @@ array rope( if (offset.size() > 1) { offset = expand_dims(offset, {-1, -2}, s); } - auto positions = - multiply(add(arange(x.shape(2), t, s), offset, s), array(scale, t), s); + auto positions = multiply( + add(arange(x.shape(2), float32, s), offset, s), + array(scale, float32), + s); - auto default_inv_freqs = [&s, &t, base, half_dims]() { + auto default_inv_freqs = [&s, base, half_dims]() { return exp( multiply( - arange(0, -half_dims, -1, t, s), - array(std::log(base) / half_dims, t), + arange(0, -half_dims, -1, float32, s), + array(std::log(base) / half_dims, float32), s), s); }; - auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s) - : default_inv_freqs(); + auto inv_freqs = + inputs.size() == 3 ? reciprocal(inputs[2], s) : default_inv_freqs(); auto theta = multiply(expand_dims(positions, -1, s), inv_freqs, s); - auto coss = cos(theta, s); - auto sins = sin(theta, s); + auto coss = astype(cos(theta, s), t, s); + auto sins = astype(sin(theta, s), t, s); auto apply_rope = [forward, s]( const array& x1, diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 4cfef7b11..f16aee05d 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -332,6 +332,26 @@ class TestFast(mlx_tests.MLXTestCase): rx = rope_orig(x, dims, traditional, base, scale, offset) self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5) + def test_rope_with_large_offset(self): + x = mx.random.normal(shape=(1, 1, 1024, 32)) + rx_fp32 = mx.fast.rope( + x, + 32, + traditional=False, + scale=1.0, + base=10000, + offset=4000, + ) + rx_bf16 = mx.fast.rope( + x.astype(mx.bfloat16), + 32, + traditional=False, + scale=1.0, + base=10000, + offset=4000, + ) + self.assertLess((rx_fp32 - rx_bf16).abs().max(), 1e-1) + def test_rms_norm(self): # Per dtype absolute tolerance tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}