diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index b1151d9a4..7ba6f5c05 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -293,7 +293,18 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { out.shape().data(), out.shape().size() * sizeof(int), 3); compute_encoder->setBytes( out.strides().data(), out.strides().size() * sizeof(size_t), 4); - compute_encoder->setBytes(&upd_size, sizeof(size_t), 5); + + size_t out_ndim = out.ndim(); + compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5); + if (upd_ndim <= 1) { + // Placeholder so Metal doesn't compalain + int shape_ = 0; + compute_encoder->setBytes(&shape_, sizeof(int), 6); + } else { + compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6); + } + compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7); + compute_encoder->setBytes(&upd_size, sizeof(size_t), 8); // Set index buffers for (int i = 0; i < nidx; ++i) { diff --git a/mlx/backend/metal/jit/indexing.h b/mlx/backend/metal/jit/indexing.h index 80d2a1e83..2227fa2f1 100644 --- a/mlx/backend/metal/jit/indexing.h +++ b/mlx/backend/metal/jit/indexing.h @@ -38,12 +38,24 @@ constexpr std::string_view scatter_kernels = R"( device mlx_atomic<{1}>* out [[buffer(2)]], const constant int* out_shape [[buffer(3)]], const constant size_t* out_strides [[buffer(4)]], - const constant size_t& upd_size [[buffer(5)]], + const constant size_t& out_ndim [[buffer(5)]], + const constant int* upd_shape [[buffer(6)]], + const constant size_t& upd_ndim [[buffer(7)]], + const constant size_t& upd_size [[buffer(8)]], {5} uint2 gid [[thread_position_in_grid]]) {{ const array idx_buffers = {{ {6} }}; return scatter_1d_index_impl<{1}, {2}, {3}, {4}>( - updates, out, out_shape, out_strides, upd_size, idx_buffers, gid); + updates, + out, + out_shape, + out_strides, + out_ndim, + upd_shape, + upd_ndim, + upd_size, + idx_buffers, + gid); }} [[kernel]] void scatter{0}_{4}( diff --git a/mlx/backend/metal/kernels/scatter.h b/mlx/backend/metal/kernels/scatter.h index 785c6134c..108e40adc 100644 --- a/mlx/backend/metal/kernels/scatter.h +++ b/mlx/backend/metal/kernels/scatter.h @@ -10,7 +10,10 @@ METAL_FUNC void scatter_1d_index_impl( device mlx_atomic* out [[buffer(2)]], const constant int* out_shape [[buffer(3)]], const constant size_t* out_strides [[buffer(4)]], - const constant size_t& upd_size [[buffer(5)]], + const constant size_t& out_ndim [[buffer(5)]], + const constant int* upd_shape [[buffer(6)]], + const constant size_t& upd_ndim [[buffer(7)]], + const constant size_t& upd_size [[buffer(8)]], const thread array& idx_buffers, uint2 gid [[thread_position_in_grid]]) { Op op; @@ -21,7 +24,14 @@ METAL_FUNC void scatter_1d_index_impl( out_idx += idx_val * out_strides[i]; } - op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x); + if (upd_ndim > 1) { + auto out_offset = elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim); + out_idx += out_offset; + } else { + out_idx += gid.x; + } + + op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx); } template diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index d15253a25..c996b8d47 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -49,7 +49,7 @@ class TestFFT(mlx_tests.MLXTestCase): self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80) self.check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120) - x = np.fft.rfft(a_np) + x = np.fft.rfft(np.real(a_np)) self.check_mx_np(mx.fft.irfft, np.fft.irfft, x) def test_fftn(self): @@ -75,9 +75,9 @@ class TestFFT(mlx_tests.MLXTestCase): if op in ["rfft2", "rfftn"]: x = r elif op == "irfft2": - x = np.ascontiguousarray(np.fft.rfft2(x, axes=ax, s=s)) + x = np.ascontiguousarray(np.fft.rfft2(r, axes=ax, s=s)) elif op == "irfftn": - x = np.ascontiguousarray(np.fft.rfftn(x, axes=ax, s=s)) + x = np.ascontiguousarray(np.fft.rfftn(r, axes=ax, s=s)) mx_op = getattr(mx.fft, op) np_op = getattr(np.fft, op) self.check_mx_np(mx_op, np_op, x, axes=ax, s=s) @@ -93,7 +93,7 @@ class TestFFT(mlx_tests.MLXTestCase): self.check_mx_np(mx.fft.rfft, np.fft.rfft, r, atol=atol, rtol=rtol) - ia_np = np.fft.rfft(a_np) + ia_np = np.fft.rfft(r) self.check_mx_np( mx.fft.irfft, np.fft.irfft, ia_np, atol=atol, rtol=rtol, n=shape[-1] ) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index f93341f96..fb4fea150 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2062,6 +2062,18 @@ TEST_CASE("test scatter") { auto out = scatter(dst, idx, src, 0); CHECK(array_equal(out, expected).item()); } + + // 1D indices with 2D update + { + auto dst = zeros({3, 4}, int32); + auto indices = {array({1}), array({2})}; + auto axes = {0, 1}; + auto updates = reshape(array({1, 2, 3, 4}, int32), {1, 2, 2}); + auto out = scatter(dst, indices, updates, axes); + auto expected = + reshape(array({0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4}), {3, 4}); + CHECK(array_equal(out, expected).item()); + } } TEST_CASE("test is positive infinity") {