diff --git a/benchmarks/python/scatter_bench.py b/benchmarks/python/scatter_bench.py index 2d63d8bf1..d2fd569ac 100644 --- a/benchmarks/python/scatter_bench.py +++ b/benchmarks/python/scatter_bench.py @@ -7,12 +7,14 @@ import torch from time_utils import measure_runtime -def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape): +def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes): def scatter(dst, x, idx): - dst[idx] = x + dst[*idx] = x mx.eval(dst) - idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape) + idx = [] + for idx_shape in idx_shapes: + idx.append(mx.random.randint(0, dst_shape[0] - 1, idx_shape)) x = mx.random.normal(x_shape).astype(mx.float32) dst = mx.random.normal(dst_shape).astype(mx.float32) @@ -20,13 +22,15 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape): print(f"MLX: {runtime:.3f}ms") -def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device): +def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device): def gather(dst, x, idx, device): - dst[idx] = x + dst[*idx] = x if device == torch.device("mps"): torch.mps.synchronize() - idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device) + idx = [] + for idx_shape in idx_shapes: + idx.append(torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)) x = torch.randn(x_shape, dtype=torch.float32).to(device) dst = torch.randn(dst_shape, dtype=torch.float32).to(device) @@ -45,9 +49,45 @@ if __name__ == "__main__": else: device = torch.device("mps") - dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)] - idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)] - x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)] + dst_shapes = [ + (10, 64), + (100_000, 64), + (1_000_000, 64), + (100_000,), + (2_000_00,), + (20_000_000,), + (10000, 64), + (100, 64), + (100, 10_000, 64), + (10, 100, 100, 21), + (1_000, 1_000, 10), + ] + idx_shapes = [ + [(1_000_000,)], + [(1_000_000,)], + [(100_000,)], + [(1_000_000,)], + [(20_000_000,)], + [(20_000_000,)], + [(1000000,)], + [(10000000,)], + [(1_000,)], + [(10_000,)], + [(1_000,), (1_000,)], + ] + x_shapes = [ + (1_000_000, 64), + (1_000_000, 64), + (100_000, 64), + (1_000_000,), + (20_000_000,), + (20_000_000,), + (1000000, 64), + (10000000, 64), + (1_000, 10_000, 64), + (10_000, 100, 100, 21), + (1_000, 10), + ] for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes): print("=" * 20) diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index 4eeb8858e..56a312a5d 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -142,7 +142,28 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { // Get kernel name std::ostringstream kname; std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; - kname << "scatter" << type_to_name(out) << idx_type_name; + + int idx_ndim = nidx ? inputs[1].ndim() : 0; + bool index_nd1_specialization = (idx_ndim == 1); + + // Bail from fast path (1d index specialization) if scatter dims aren't + // the outermost dims and contiguous since update access won't be raster + // order. + for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) { + index_nd1_specialization &= (axes_[i] == i); + } + + // Bail from fast path (1d index specialization) if any of the dims are + // broadcasted, since we can't rely on linear indexing in that case. + for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) { + index_nd1_specialization &= inputs[i].flags().row_contiguous; + } + + if (index_nd1_specialization) { + kname << "scatter_1d_index" << type_to_name(out) << idx_type_name; + } else { + kname << "scatter" << type_to_name(out) << idx_type_name; + } switch (reduce_type_) { case Scatter::None: kname << "_none"; @@ -170,85 +191,106 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setComputePipelineState(kernel); - // Collect all idx shapes and strides into one place - int idx_ndim = nidx ? inputs[1].ndim() : 0; - std::vector idx_shapes; - std::vector idx_strides; - - for (int i = 0; i < nidx; ++i) { - idx_shapes.insert( - idx_shapes.end(), - inputs[i + 1].shape().begin(), - inputs[i + 1].shape().end()); - - idx_strides.insert( - idx_strides.end(), - inputs[i + 1].strides().begin(), - inputs[i + 1].strides().end()); - } - // Set all the buffers set_array_buffer(compute_encoder, upd, 1); set_array_buffer(compute_encoder, out, 2); // Set update info - size_t upd_ndim = upd.ndim(); + uint upd_ndim = upd.ndim(); size_t upd_size = 1; for (int i = idx_ndim; i < upd.ndim(); ++i) { upd_size *= upd.shape(i); } - if (upd_ndim == 0) { - // Need placeholders so Metal doesn't compalain - int shape_ = 0; - size_t stride_ = 0; - compute_encoder->setBytes(&shape_, sizeof(int), 3); - compute_encoder->setBytes(&stride_, sizeof(size_t), 4); - } else { - compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3); + + if (index_nd1_specialization) { + bool upd_col_contiguous = upd.flags().col_contiguous; compute_encoder->setBytes( - upd.strides().data(), upd_ndim * sizeof(size_t), 4); - } - compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5); - compute_encoder->setBytes(&upd_size, sizeof(size_t), 6); - - // Set output info - size_t out_ndim = out.ndim(); - if (out_ndim == 0) { - // Need placeholders so Metal doesn't compalain - int shape_ = 0; - size_t stride_ = 0; - compute_encoder->setBytes(&shape_, sizeof(int), 7); - compute_encoder->setBytes(&stride_, sizeof(size_t), 8); - } else { - compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7); + out.shape().data(), out.shape().size() * sizeof(int), 3); compute_encoder->setBytes( - out.strides().data(), out_ndim * sizeof(size_t), 8); - } - compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9); - compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10); + out.strides().data(), out.strides().size() * sizeof(size_t), 4); + compute_encoder->setBytes(&upd_size, sizeof(size_t), 5); + compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6); - // Set index info - if (idx_ndim == 0) { - // Add a 0 in idx_shapes and strides to avoid the missing buffer binding - // error in the metal API. - idx_shapes.push_back(0); - idx_strides.push_back(0); - } - compute_encoder->setBytes( - idx_shapes.data(), idx_shapes.size() * sizeof(int), 11); - compute_encoder->setBytes( - idx_strides.data(), idx_strides.size() * sizeof(size_t), 12); - compute_encoder->setBytes(&idx_ndim, sizeof(int), 13); + // Set index buffers + for (int i = 1; i < nidx + 1; ++i) { + set_array_buffer(compute_encoder, inputs[i], 20 + i); + } - // Set index buffers - for (int i = 1; i < nidx + 1; ++i) { - set_array_buffer(compute_encoder, inputs[i], 20 + i); - } + // Launch grid + MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1); + MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); - // Launch grid - MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1); - MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1); - compute_encoder->dispatchThreads(grid_dims, group_dims); + } else { + // Collect all idx shapes and strides into one place + std::vector idx_shapes; + std::vector idx_strides; + + for (int i = 0; i < nidx; ++i) { + idx_shapes.insert( + idx_shapes.end(), + inputs[i + 1].shape().begin(), + inputs[i + 1].shape().end()); + + idx_strides.insert( + idx_strides.end(), + inputs[i + 1].strides().begin(), + inputs[i + 1].strides().end()); + } + + if (upd_ndim == 0) { + // Need placeholders so Metal doesn't compalain + int shape_ = 0; + size_t stride_ = 0; + compute_encoder->setBytes(&shape_, sizeof(int), 3); + compute_encoder->setBytes(&stride_, sizeof(size_t), 4); + } else { + compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3); + compute_encoder->setBytes( + upd.strides().data(), upd_ndim * sizeof(size_t), 4); + } + compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5); + compute_encoder->setBytes(&upd_size, sizeof(size_t), 6); + + // Set output info + size_t out_ndim = out.ndim(); + if (out_ndim == 0) { + // Need placeholders so Metal doesn't compalain + int shape_ = 0; + size_t stride_ = 0; + compute_encoder->setBytes(&shape_, sizeof(int), 7); + compute_encoder->setBytes(&stride_, sizeof(size_t), 8); + } else { + compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7); + compute_encoder->setBytes( + out.strides().data(), out_ndim * sizeof(size_t), 8); + } + compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9); + compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10); + + // Set index info + if (idx_ndim == 0) { + // Add a 0 in idx_shapes and strides to avoid the missing buffer binding + // error in the metal API. + idx_shapes.push_back(0); + idx_strides.push_back(0); + } + compute_encoder->setBytes( + idx_shapes.data(), idx_shapes.size() * sizeof(int), 11); + compute_encoder->setBytes( + idx_strides.data(), idx_strides.size() * sizeof(size_t), 12); + compute_encoder->setBytes(&idx_ndim, sizeof(int), 13); + + // Set index buffers + for (int i = 1; i < nidx + 1; ++i) { + set_array_buffer(compute_encoder, inputs[i], 20 + i); + } + + // Launch grid + MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1); + MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } } } // namespace mlx::core diff --git a/mlx/backend/metal/kernels/scatter.metal b/mlx/backend/metal/kernels/scatter.metal index 7a94be7da..071effeea 100644 --- a/mlx/backend/metal/kernels/scatter.metal +++ b/mlx/backend/metal/kernels/scatter.metal @@ -13,6 +13,58 @@ using namespace metal; // Scatter kernel ///////////////////////////////////////////////////////////////////// +template \ +METAL_FUNC void scatter_1d_index_impl( + const device T *updates [[buffer(1)]], + device mlx_atomic *out [[buffer(2)]], + const constant int* out_shape [[buffer(3)]], + const constant size_t* out_strides [[buffer(4)]], + const constant size_t& upd_size [[buffer(5)]], + const constant bool& upd_col_contiguous [[buffer(6)]], + const thread array& idx_buffers, + uint2 gid [[thread_position_in_grid]]) { + + Op op; + + uint out_idx = 0; + for (int i = 0; i < NIDX; i++) { + auto idx_val = offset_neg_idx( + idx_buffers[i][gid.y], out_shape[i]); + out_idx += idx_val * out_strides[i]; + } + + if (!upd_col_contiguous) { + op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x); + } else { + op.atomic_update(out, updates[gid.x * upd_size + gid.y], out_idx + gid.x); + } +} + +#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \ +template \ +[[kernel]] void scatter_1d_index( \ + const device T *updates [[buffer(1)]], \ + device mlx_atomic *out [[buffer(2)]], \ + const constant int* out_shape [[buffer(3)]], \ + const constant size_t* out_strides [[buffer(4)]], \ + const constant size_t& upd_size [[buffer(5)]], \ + const constant bool& upd_col_contiguous [[buffer(6)]], \ + IDX_ARG(IdxT) \ + uint2 gid [[thread_position_in_grid]]) { \ + \ + const array idx_buffers = {IDX_ARR()}; \ + \ + return scatter_1d_index_impl( \ + updates, \ + out, \ + out_shape, \ + out_strides, \ + upd_size, \ + upd_col_contiguous, \ + idx_buffers, \ + gid); \ + \ +} template METAL_FUNC void scatter_impl( @@ -46,10 +98,14 @@ METAL_FUNC void scatter_impl( out_idx += idx_val * out_strides[ax]; } - auto out_offset = elem_to_loc( - ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); + if (upd_size > 1) { + auto out_offset = elem_to_loc( + ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); + out_idx += out_offset; + } + auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim); - op.atomic_update(out, updates[upd_idx], out_idx + out_offset); + op.atomic_update(out, updates[upd_idx], out_idx); } #define make_scatter_impl(IDX_ARG, IDX_ARR) \ @@ -90,9 +146,11 @@ template \ axes, \ idxs, \ gid); \ -} +} -#define make_scatter(n) make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) +#define make_scatter(n) \ +make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) \ +make_scatter_1d_index(IDX_ARG_ ##n, IDX_ARR_ ##n) make_scatter(0) make_scatter(1) @@ -129,8 +187,21 @@ template [[host_name("scatter" name "_" #nidx)]] \ IDX_ARG(idx_t) \ uint2 gid [[thread_position_in_grid]]); +#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \ +template [[host_name("scatter_1d_index" name "_" #nidx)]] \ +[[kernel]] void scatter_1d_index( \ + const device src_t *updates [[buffer(1)]], \ + device mlx_atomic *out [[buffer(2)]], \ + const constant int* out_shape [[buffer(3)]], \ + const constant size_t* out_strides [[buffer(4)]], \ + const constant size_t& upd_size [[buffer(5)]], \ + const constant bool& upd_col_contiguous [[buffer(6)]], \ + IDX_ARG(idx_t) \ + uint2 gid [[thread_position_in_grid]]); + #define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \ - instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) + instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \ + instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) // Special case NINDEX=0 #define instantiate_scatter_nd0(name, type) \ diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index ba4ab552f..67ff71b12 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1858,6 +1858,14 @@ TEST_CASE("test scatter") { out = scatter(in, inds, updates, 0); CHECK(array_equal(out, reshape(arange(16, float32), {4, 4})).item()); + // Array scatters with col contiguous updates + in = zeros({4, 4}, float32); + inds = array({0, 1, 2, 3}); + updates = transpose(reshape(arange(16, float32), {4, 1, 4})); + out = scatter(in, inds, updates, 0); + CHECK(array_equal(out, transpose(reshape(arange(16, float32), {4, 4}))) + .item()); + // Irregular strided index and reduce collision test in = zeros({10}, float32); inds = broadcast_to(array(3), {10}); @@ -1877,10 +1885,10 @@ TEST_CASE("test scatter") { // Irregularly strided updates test in = ones({3, 3}); - updates = broadcast_to(array({0, 0, 0}), {1, 3, 3}); + updates = broadcast_to(array({2, 2, 2}), {1, 3, 3}); inds = array({0}); out = scatter(in, inds, updates, 0); - CHECK(array_equal(out, zeros({3, 3})).item()); + CHECK(array_equal(out, ones({3, 3}) * 2).item()); // Along different axis in = zeros({2, 3});