From e319383ef92c09651c6808b22d1465018e01ea73 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 4 Feb 2024 17:25:44 -0800 Subject: [PATCH] Faster gather (#626) * faster gather * update copyright --- benchmarks/python/gather_bench.py | 64 +++++++++++++++++ mlx/backend/metal/indexing.cpp | 27 +++---- mlx/backend/metal/kernels/indexing.metal | 92 ++++++++++++++++-------- 3 files changed, 142 insertions(+), 41 deletions(-) create mode 100644 benchmarks/python/gather_bench.py diff --git a/benchmarks/python/gather_bench.py b/benchmarks/python/gather_bench.py new file mode 100644 index 000000000..8ce6dcb8f --- /dev/null +++ b/benchmarks/python/gather_bench.py @@ -0,0 +1,64 @@ +# Copyright © 2023-2024 Apple Inc. + +import argparse +from time import time + +import mlx.core as mx +import torch + + +def measure_runtime(fn, **kwargs): + # Warmup + for _ in range(5): + fn(**kwargs) + + tic = time() + iters = 10 + for _ in range(iters): + fn(**kwargs) + return (time() - tic) * 1000 / iters + + +def benchmark_gather_mlx(x_shape, idx_shape): + def gather(x, idx): + mx.eval(x[idx]) + + idx = mx.random.randint(0, x_shape[0] - 1, idx_shape) + x = mx.random.normal(x_shape).astype(mx.float32) + + runtime = measure_runtime(gather, x=x, idx=idx) + print(f"MLX: {runtime:.3f}ms") + + +def benchmark_gather_torch(x_shape, idx_shape, device): + def gather(x, idx, device): + _ = x[idx] + if device == torch.device("mps"): + torch.mps.synchronize() + + idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device) + x = torch.randn(x_shape, dtype=torch.float32).to(device) + + runtime = measure_runtime(gather, x=x, idx=idx, device=device) + print(f"PyTorch: {runtime:.3f}ms") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Gather benchmarks.") + parser.add_argument("--cpu", action="store_true", help="Use the CPU.") + args = parser.parse_args() + + if args.cpu: + mx.set_default_device(mx.cpu) + device = torch.device("cpu") + else: + device = torch.device("mps") + + idx_shapes = [(1_000_000,), (100_000,), ()] + x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)] + + for x_shape, idx_shape in zip(x_shapes, idx_shapes): + print("=" * 20) + print(f"X {x_shape}, Indices {idx_shape}") + benchmark_gather_mlx(x_shape, idx_shape) + benchmark_gather_torch(x_shape, idx_shape, device=device) diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index ce543bd67..28edb9126 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include #include @@ -39,9 +39,15 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& d = metal::device(s.device); + int idx_ndim = nidx ? inputs[1].ndim() : 0; + size_t ndim = src.ndim(); + std::ostringstream kname; std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx; + if (idx_ndim <= 1) { + kname << "_" << idx_ndim; + } auto compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); @@ -51,15 +57,11 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { slice_size *= s; } - size_t ndim = src.ndim(); - size_t nthreads = out.size(); - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (thread_group_size > nthreads) { - thread_group_size = nthreads; - } - - MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); - MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + // Launch 2D grid of threads: indices x slice + size_t dim0 = out.size() / slice_size; + size_t dim1 = slice_size; + auto group_dims = get_block_dims(dim0, dim1, 1); + MTL::Size grid_dims = MTL::Size(dim0, dim1, 1); compute_encoder->setComputePipelineState(kernel); @@ -90,7 +92,6 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { auto arg_enc = d.argument_encoder(arg_descs); // Allocate and fill buffers for shapes and strides - int idx_ndim = nidx ? inputs[1].ndim() : 0; auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim); auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim); for (int i = 0; i < nidx; ++i) { @@ -130,12 +131,12 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { set_array_buffer(compute_encoder, src, 0); compute_encoder->setBuffer(static_cast(arg_buf.ptr()), 0, 1); set_array_buffer(compute_encoder, out, 2); + compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3); compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4); compute_encoder->setBytes(&ndim, sizeof(size_t), 5); compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6); - compute_encoder->setBytes(&slice_size, sizeof(size_t), 7); - compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 8); + compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 7); compute_encoder->dispatchThreads(grid_dims, group_dims); diff --git a/mlx/backend/metal/kernels/indexing.metal b/mlx/backend/metal/kernels/indexing.metal index 82642d250..395bc7819 100644 --- a/mlx/backend/metal/kernels/indexing.metal +++ b/mlx/backend/metal/kernels/indexing.metal @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include @@ -36,29 +36,38 @@ inline size_t offset_neg_idx(uint32_t idx, size_t) { return idx; } -template +// IDX_NDIM is the number of dimensions of the indices arrays. Compile-time +// special case for 0 and 1. Anything >= 2 uses the general case +template [[kernel]] void gather( const device T *src [[buffer(0)]], - const device Indices& indices [[buffer(1)]], + const constant Indices& indices [[buffer(1)]], device T *out [[buffer(2)]], - const device int *src_shape [[buffer(3)]], - const device size_t *src_strides [[buffer(4)]], - const device size_t& src_ndim [[buffer(5)]], - const device int *slice_sizes [[buffer(6)]], - const device size_t& slice_size [[buffer(7)]], - const device int *axes [[buffer(8)]], - uint gid [[thread_position_in_grid]]) { + const constant int *src_shape [[buffer(3)]], + const constant size_t *src_strides [[buffer(4)]], + const constant size_t& src_ndim [[buffer(5)]], + const constant int *slice_sizes [[buffer(6)]], + const constant int *axes [[buffer(7)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { - auto ind_idx = gid / slice_size; - auto ind_offset = gid % slice_size; + auto ind_idx = index.x; + auto ind_offset = index.y; size_t src_idx = 0; for (int i = 0; i < NIDX; ++i) { - auto idx_loc = elem_to_loc( - ind_idx, - &indices.shapes[indices.ndim * i], - &indices.strides[indices.ndim * i], - indices.ndim); + size_t idx_loc; + if (IDX_NDIM == 0) { + idx_loc = 0; + } else if (IDX_NDIM == 1) { + idx_loc = ind_idx * indices.strides[indices.ndim * i]; + } else { + idx_loc = elem_to_loc( + ind_idx, + &indices.shapes[indices.ndim * i], + &indices.strides[indices.ndim * i], + indices.ndim); + } auto ax = axes[i]; auto idx_val = offset_neg_idx( indices.buffers[i][idx_loc], src_shape[ax]); @@ -67,22 +76,49 @@ template auto src_offset = elem_to_loc( ind_offset, slice_sizes, src_strides, src_ndim); - out[gid] = src[src_idx + src_offset]; + + size_t out_idx = index.y + static_cast(grid_dim.y) * index.x; + out[out_idx] = src[src_offset + src_idx]; } #define instantiate_gather4(name, src_type, ind_type, nindex) \ -template [[host_name("gather" name "_" #nindex)]] \ -[[kernel]] void gather( \ +template [[host_name("gather" name "_" #nindex "_0")]] \ +[[kernel]] void gather( \ const device src_type *src [[buffer(0)]], \ - const device Indices& indices [[buffer(1)]], \ + const constant Indices& indices [[buffer(1)]], \ device src_type *out [[buffer(2)]], \ - const device int *src_shape [[buffer(3)]], \ - const device size_t *src_strides [[buffer(4)]], \ - const device size_t& src_ndim [[buffer(5)]], \ - const device int *slice_sizes [[buffer(6)]], \ - const device size_t& slice_size [[buffer(7)]], \ - const device int* axes [[buffer(8)]], \ - uint gid [[thread_position_in_grid]]); + const constant int *src_shape [[buffer(3)]], \ + const constant size_t *src_strides [[buffer(4)]], \ + const constant size_t& src_ndim [[buffer(5)]], \ + const constant int *slice_sizes [[buffer(6)]], \ + const constant int* axes [[buffer(7)]], \ + uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]); \ +template [[host_name("gather" name "_" #nindex "_1")]] \ +[[kernel]] void gather( \ + const device src_type *src [[buffer(0)]], \ + const constant Indices& indices [[buffer(1)]], \ + device src_type *out [[buffer(2)]], \ + const constant int *src_shape [[buffer(3)]], \ + const constant size_t *src_strides [[buffer(4)]], \ + const constant size_t& src_ndim [[buffer(5)]], \ + const constant int *slice_sizes [[buffer(6)]], \ + const constant int* axes [[buffer(7)]], \ + uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]); \ +template [[host_name("gather" name "_" #nindex)]] \ +[[kernel]] void gather( \ + const device src_type *src [[buffer(0)]], \ + const constant Indices& indices [[buffer(1)]], \ + device src_type *out [[buffer(2)]], \ + const constant int *src_shape [[buffer(3)]], \ + const constant size_t *src_strides [[buffer(4)]], \ + const constant size_t& src_ndim [[buffer(5)]], \ + const constant int *slice_sizes [[buffer(6)]], \ + const constant int* axes [[buffer(7)]], \ + uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]); + // Special for case NIDX=0 instantiate_gather4("bool_", bool, bool, 0)