From d40e76809f9ddbd32dcc86a34772392590a7eaca Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 20 Aug 2024 17:37:52 -0700 Subject: [PATCH] Fix rope (#1340) * add test * fix rope * fix test --- mlx/backend/metal/kernels/rope.metal | 23 +++++++++-------------- mlx/backend/metal/rope.cpp | 2 +- mlx/fast.cpp | 2 +- python/tests/test_fast.py | 16 +++++++++++++++- 4 files changed, 26 insertions(+), 17 deletions(-) diff --git a/mlx/backend/metal/kernels/rope.metal b/mlx/backend/metal/kernels/rope.metal index d6f44591e..a38cfcdff 100644 --- a/mlx/backend/metal/kernels/rope.metal +++ b/mlx/backend/metal/kernels/rope.metal @@ -22,23 +22,18 @@ void rope_single_impl( float sintheta = metal::fast::sin(theta); // Compute the input and output indices - uint in_index_1, in_index_2; - uint out_index_1, out_index_2; + uint index_1, index_2; if (traditional) { - out_index_1 = 2 * pos.x + pos.y * stride; - out_index_2 = out_index_1 + 1; - in_index_1 = 2 * pos.x + pos.y * stride; - in_index_2 = in_index_1 + 1; + index_1 = 2 * pos.x + pos.y * stride; + index_2 = index_1 + 1; } else { - out_index_1 = pos.x + pos.y * stride; - out_index_2 = out_index_1 + grid.x; - in_index_1 = pos.x + pos.y * stride; - in_index_2 = in_index_1 + grid.x; + index_1 = pos.x + pos.y * stride; + index_2 = index_1 + grid.x; } // Read and write the output - float x1 = static_cast<float>(in[in_index_1]); - float x2 = static_cast<float>(in[in_index_2]); + float x1 = static_cast<float>(in[index_1]); + float x2 = static_cast<float>(in[index_2]); float rx1; float rx2; if (forward) { @@ -48,8 +43,8 @@ void rope_single_impl( rx1 = x2 * sintheta + x1 * costheta; rx2 = x2 * costheta - x1 * sintheta; } - out[out_index_1] = static_cast<T>(rx1); - out[out_index_2] = static_cast<T>(rx2); + out[index_1] = static_cast<T>(rx1); + out[index_2] = static_cast<T>(rx2); } template <typename T, bool traditional, bool forward> diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp index 174f63bbc..d1d07df2c 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -86,7 +86,7 @@ 
void RoPE::eval_gpu( MTL::Size group_dims; MTL::Size grid_dims; if (single) { - compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 4); + compute_encoder->setBytes(out_strides, sizeof(size_t), 4); uint32_t dim0 = dims_ / 2; group_dims = get_block_dims(dim0, n_batch, 1); grid_dims = MTL::Size(dim0, n_batch, 1); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index dfe68827f..4a8cf791b 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -340,7 +340,7 @@ array rope( if (inputs.size() == 2 && (inputs[1].ndim() != 1 || inputs[1].shape(0) != dims / 2)) { std::ostringstream msg; - msg << "[rope] freqs must be one dimensional with size " << dims + msg << "[rope] freqs must be one dimensional with size " << dims / 2 << " but got shape " << inputs[1].shape() << "."; throw std::invalid_argument(msg.str()); } diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index c68f7e423..b15b737b3 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -8,7 +8,7 @@ import mlx_tests def rope_orig(x, dims, traditional, base, scale, offset, freqs=None): - N = x.shape[1] + offset + N = x.shape[-2] + offset dtype = x.dtype half_D = dims // 2 positions = mx.arange(offset, N, dtype=dtype) * scale @@ -143,6 +143,20 @@ class TestFast(mlx_tests.MLXTestCase): ) self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + # Test transpose into rope + dims, _, base, scale, offset, traditional = defaults + x = mx.random.uniform(shape=(1, 1, 4, dims)).swapaxes(1, 2) + rx = rope_orig(x, dims, traditional, base, scale, offset) + rx_fast = mx.fast.rope( + 1.0 * x, # multiply here to allow donation + dims, + traditional=traditional, + base=base, + scale=scale, + offset=offset, + ) + self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[mx.float32]) + def test_rope_with_freqs(self): # Check throws T = 4