From 121d9a07021b245b9adcf6ee9ac417102d565570 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 26 Jan 2025 19:07:21 -0800 Subject: [PATCH] Fix rope fallback to not upcast (#1797) * fix rope fallback to not upcast * Update mlx/fast.cpp Co-authored-by: Angelos Katharopoulos --------- Co-authored-by: Angelos Katharopoulos --- mlx/fast.cpp | 4 ++-- python/tests/test_fast.py | 28 ++++++++++++++++------------ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 79f73aee6..14195b9e4 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -392,8 +392,8 @@ array rope( s); }; - auto inv_freqs = - inputs.size() == 3 ? reciprocal(inputs[2], s) : default_inv_freqs(); + auto inv_freqs = inputs.size() == 3 ? astype(reciprocal(inputs[2], s), t, s) + : default_inv_freqs(); auto theta = multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s); auto coss = cos(theta, s); diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 524b43d1b..fd9a87c9c 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -18,7 +18,7 @@ def rope_orig(x, dims, traditional, base, scale, offset, freqs=None): -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D) ) else: - inv_freqs = 1 / freqs + inv_freqs = (1 / freqs).astype(x.dtype) theta = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freqs, (1, -1)) costheta, sintheta = mx.cos(theta), mx.sin(theta) if traditional: @@ -189,17 +189,21 @@ class TestFast(mlx_tests.MLXTestCase): freqs = mx.random.uniform(shape=(dims // 2,)) - rx = rope_orig(x, dims, False, None, 1.0, 0, freqs) - rx_fast = mx.fast.rope( - x, - dims, - traditional=False, - base=None, - scale=1.0, - offset=0, - freqs=freqs, - ) - self.assertLess(mx.abs(rx - rx_fast).max(), 1e-5) + tolerances = {mx.float32: 1e-5, mx.float16: 1e-2} + for dtype in [mx.float32, mx.float16]: + x_ = x.astype(dtype) + rx = rope_orig(x_, dims, False, None, 1.0, 0, freqs) + rx_fast = mx.fast.rope( + x_, + dims, + traditional=False, + base=None, + scale=1.0, + offset=0, + freqs=freqs, + ) + self.assertEqual(dtype, rx.dtype) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) # Test single vector x = mx.random.uniform(shape=(1, 1, dims))