From 06072601cef7876a23d079d9f8c28df28f6e28e8 Mon Sep 17 00:00:00 2001
From: Vijay Krish
Date: Sat, 10 Feb 2024 08:49:51 -0800
Subject: [PATCH] Scatter optimization: Eliminate 64b integer divide. (#662)

Launch a 2D grid to eliminate the divide and mod in device code, since
64-bit integer division is very expensive.

GitHub Issue #506

Co-authored-by: Vijay Krishnamoorthy
---
 benchmarks/python/gather_bench.py        | 13 +-----
 benchmarks/python/scatter_bench.py       | 56 ++++++++++++++++++++++++
 benchmarks/python/time_utils.py          | 14 +++++-
 mlx/backend/metal/indexing.cpp           |  5 +--
 mlx/backend/metal/kernels/indexing.metal | 10 ++---
 5 files changed, 77 insertions(+), 21 deletions(-)
 create mode 100644 benchmarks/python/scatter_bench.py

diff --git a/benchmarks/python/gather_bench.py b/benchmarks/python/gather_bench.py
index 8ce6dcb8f..e000841d2 100644
--- a/benchmarks/python/gather_bench.py
+++ b/benchmarks/python/gather_bench.py
@@ -5,18 +5,7 @@ from time import time
 
 import mlx.core as mx
 import torch
-
-
-def measure_runtime(fn, **kwargs):
-    # Warmup
-    for _ in range(5):
-        fn(**kwargs)
-
-    tic = time()
-    iters = 10
-    for _ in range(iters):
-        fn(**kwargs)
-    return (time() - tic) * 1000 / iters
+from time_utils import measure_runtime
 
 
 def benchmark_gather_mlx(x_shape, idx_shape):
diff --git a/benchmarks/python/scatter_bench.py b/benchmarks/python/scatter_bench.py
new file mode 100644
index 000000000..2d63d8bf1
--- /dev/null
+++ b/benchmarks/python/scatter_bench.py
@@ -0,0 +1,56 @@
+# Copyright © 2023-2024 Apple Inc.
+
+import argparse
+
+import mlx.core as mx
+import torch
+from time_utils import measure_runtime
+
+
+def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
+    def scatter(dst, x, idx):
+        dst[idx] = x
+        mx.eval(dst)
+
+    idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape)
+    x = mx.random.normal(x_shape).astype(mx.float32)
+    dst = mx.random.normal(dst_shape).astype(mx.float32)
+
+    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
+    print(f"MLX: {runtime:.3f}ms")
+
+
+def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device):
+    def scatter(dst, x, idx, device):
+        dst[idx] = x
+        if device == torch.device("mps"):
+            torch.mps.synchronize()
+
+    idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)
+    x = torch.randn(x_shape, dtype=torch.float32).to(device)
+    dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
+
+    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
+    print(f"PyTorch: {runtime:.3f}ms")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("Scatter benchmarks.")
+    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
+    args = parser.parse_args()
+
+    if args.cpu:
+        mx.set_default_device(mx.cpu)
+        device = torch.device("cpu")
+    else:
+        device = torch.device("mps")
+
+    dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)]
+    idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)]
+    x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)]
+
+    for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
+        print("=" * 20)
+        print(f"X {x_shape}, Indices {idx_shape}")
+        benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
+        benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
diff --git a/benchmarks/python/time_utils.py b/benchmarks/python/time_utils.py
index 73266cb7a..32d22ed99 100644
--- a/benchmarks/python/time_utils.py
+++ b/benchmarks/python/time_utils.py
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import time
 
@@ -20,3 +20,15 @@ def time_fn(fn, *args, **kwargs):
 
     msec = 1e3 * (toc - tic) / num_iters
     print(f"{msec:.5f} msec")
+
+
+def measure_runtime(fn, **kwargs):
+    # Warmup
+    for _ in range(5):
+        fn(**kwargs)
+
+    tic = time.time()
+    iters = 10
+    for _ in range(iters):
+        fn(**kwargs)
+    return (time.time() - tic) * 1000 / iters
diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp
index 28edb9126..cf2256846 100644
--- a/mlx/backend/metal/indexing.cpp
+++ b/mlx/backend/metal/indexing.cpp
@@ -212,9 +212,6 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     thread_group_size = nthreads;
   }
 
-  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
-  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-
   compute_encoder->setComputePipelineState(kernel);
 
   // Make the argument buffer to store the indices for the
@@ -317,6 +314,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
   compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
 
+  MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
+  MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
   compute_encoder->dispatchThreads(grid_dims, group_dims);
 
   // Cleanup temporaries
diff --git a/mlx/backend/metal/kernels/indexing.metal b/mlx/backend/metal/kernels/indexing.metal
index 395bc7819..7b6e2399a 100644
--- a/mlx/backend/metal/kernels/indexing.metal
+++ b/mlx/backend/metal/kernels/indexing.metal
@@ -187,11 +187,11 @@ template <typename T, typename IdxT, typename Op, int NIDX>
     const device size_t *out_strides [[buffer(8)]],
     const device size_t& out_ndim [[buffer(9)]],
     const device int* axes [[buffer(10)]],
-    uint gid [[thread_position_in_grid]]) {
+    uint2 gid [[thread_position_in_grid]]) {
   Op op;
 
-  auto ind_idx = gid / upd_size;
-  auto ind_offset = gid % upd_size;
+  auto ind_idx = gid.y;
+  auto ind_offset = gid.x;
 
   size_t out_idx = 0;
   for (int i = 0; i < NIDX; ++i) {
@@ -208,7 +208,7 @@ template <typename T, typename IdxT, typename Op, int NIDX>
   auto out_offset = elem_to_loc(
       ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
 
-  auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
+  auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
 
   op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
 }
@@ -226,7 +226,7 @@ template [[host_name("scatter" name "_" #nindex)]] \
     const device size_t *out_strides [[buffer(8)]], \
     const device size_t& out_ndim [[buffer(9)]], \
     const device int* axes [[buffer(10)]], \
-    uint gid [[thread_position_in_grid]]);
+    uint2 gid [[thread_position_in_grid]]);
 
 // Special case NINDEX=0
 #define instantiate_scatter_nd0(name, type) \
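
--
Editor's note, not part of the commit: below is a minimal, runnable Python
sketch of the index arithmetic the 2D launch relies on. The names upd_size
and nthreads mirror the host-side variables in Scatter::eval_gpu; the nested
loops stand in for the two grid dimensions, and the concrete sizes are made
up for illustration. With grid_dims = (upd_size, nthreads / upd_size, 1),
the thread at position (gid.x, gid.y) corresponds to the old linear thread
id gid.y * upd_size + gid.x, so gid.x and gid.y already hold the ind_offset
and ind_idx that the old 1D kernel had to recover with a 64-bit divide and
mod:

    # Sketch only. Assumes nthreads is a multiple of upd_size, which the
    # 2D dispatch (upd_size, nthreads / upd_size, 1) in the patch implies.
    upd_size = 64
    nthreads = 100 * upd_size

    for y in range(nthreads // upd_size):   # second grid dimension -> gid.y
        for x in range(upd_size):           # first grid dimension  -> gid.x
            gid_1d = y * upd_size + x       # old linear thread id
            assert gid_1d // upd_size == y  # old ind_idx (64b divide)
            assert gid_1d % upd_size == x   # old ind_offset (64b mod)

The one place the kernel still needs the linear id, the upd_idx lookup, now
rebuilds it with a multiply-add (gid.y * upd_size + gid.x) instead of
performing the division.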
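
To reproduce the numbers, run the new script from benchmarks/python (it
imports measure_runtime from time_utils.py in the same directory):
python scatter_bench.py, or python scatter_bench.py --cpu to compare the
CPU backends. For each dst/x/idx shape triple it prints one MLX and one
PyTorch runtime in milliseconds, so the effect of this patch shows up
directly in the MLX lines.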