From 8b7532b9ab1db87504c4cc498806f842d684a7a0 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 12 Mar 2024 11:42:07 -0700 Subject: [PATCH] fix scatter (#821) --- mlx/backend/metal/indexing.cpp | 3 --- mlx/backend/metal/kernels/scatter.metal | 10 +--------- tests/ops_tests.cpp | 12 +++++++++++- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index 56a312a5d..13169c4f1 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -201,15 +201,12 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { for (int i = idx_ndim; i < upd.ndim(); ++i) { upd_size *= upd.shape(i); } - if (index_nd1_specialization) { - bool upd_col_contiguous = upd.flags().col_contiguous; compute_encoder->setBytes( out.shape().data(), out.shape().size() * sizeof(int), 3); compute_encoder->setBytes( out.strides().data(), out.strides().size() * sizeof(size_t), 4); compute_encoder->setBytes(&upd_size, sizeof(size_t), 5); - compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6); // Set index buffers for (int i = 1; i < nidx + 1; ++i) { diff --git a/mlx/backend/metal/kernels/scatter.metal b/mlx/backend/metal/kernels/scatter.metal index d8cec1336..fd2bd1950 100644 --- a/mlx/backend/metal/kernels/scatter.metal +++ b/mlx/backend/metal/kernels/scatter.metal @@ -20,7 +20,6 @@ METAL_FUNC void scatter_1d_index_impl( const constant int* out_shape [[buffer(3)]], const constant size_t* out_strides [[buffer(4)]], const constant size_t& upd_size [[buffer(5)]], - const constant bool& upd_col_contiguous [[buffer(6)]], const thread array& idx_buffers, uint2 gid [[thread_position_in_grid]]) { @@ -33,11 +32,7 @@ METAL_FUNC void scatter_1d_index_impl( out_idx += idx_val * out_strides[i]; } - if (!upd_col_contiguous) { - op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x); - } else { - op.atomic_update(out, updates[gid.x * upd_size + gid.y], out_idx + gid.x); - } + op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x); } #define make_scatter_1d_index(IDX_ARG, IDX_ARR) \ @@ -48,7 +43,6 @@ template \ const constant int* out_shape [[buffer(3)]], \ const constant size_t* out_strides [[buffer(4)]], \ const constant size_t& upd_size [[buffer(5)]], \ - const constant bool& upd_col_contiguous [[buffer(6)]], \ IDX_ARG(IdxT) \ uint2 gid [[thread_position_in_grid]]) { \ \ @@ -60,7 +54,6 @@ template \ out_shape, \ out_strides, \ upd_size, \ - upd_col_contiguous, \ idx_buffers, \ gid); \ \ @@ -195,7 +188,6 @@ template [[host_name("scatter_1d_index" name "_" #nidx)]] \ const constant int* out_shape [[buffer(3)]], \ const constant size_t* out_strides [[buffer(4)]], \ const constant size_t& upd_size [[buffer(5)]], \ - const constant bool& upd_col_contiguous [[buffer(6)]], \ IDX_ARG(idx_t) \ uint2 gid [[thread_position_in_grid]]); diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 863b3fd72..953fa7839 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include @@ -1919,6 +1919,16 @@ TEST_CASE("test scatter") { inds = array({0, 1}); out = scatter_add(in, inds, updates, 0); CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item()); + + // 1D scatter + { + auto dst = zeros({2, 4}, int32); + auto src = reshape(array({1, 2, 3, 4}), {1, 1, 4}); + auto idx = array({1}); + auto expected = reshape(array({0, 0, 0, 0, 1, 2, 3, 4}), {2, 4}); + auto out = scatter(dst, idx, src, 0); + CHECK(array_equal(out, expected).item()); + } } TEST_CASE("test is positive infinity") {