From c9b41d460f123b1fb935f9dcf0b8a4cb5582a7f1 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 24 Oct 2024 11:05:46 -0700
Subject: [PATCH] Working 64-bit scans (#1506)

---
 mlx/backend/metal/jit/scan.h         |  26 ----
 mlx/backend/metal/jit_kernels.cpp    |  31 +++--
 mlx/backend/metal/kernels/scan.h     | 164 +++++++++++++++++----
 mlx/backend/metal/kernels/scan.metal |  50 ++++----
 mlx/backend/metal/kernels/utils.h    |  60 ++++++++++
 mlx/backend/metal/scan.cpp           |  53 ++++++---
 mlx/backend/metal/utils.cpp          |  42 +++++++
 mlx/backend/metal/utils.h            |   8 ++
 python/tests/test_ops.py             |  12 ++
 9 files changed, 309 insertions(+), 137 deletions(-)
 delete mode 100644 mlx/backend/metal/jit/scan.h

diff --git a/mlx/backend/metal/jit/scan.h b/mlx/backend/metal/jit/scan.h
deleted file mode 100644
index e3e40cf9b..000000000
--- a/mlx/backend/metal/jit/scan.h
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright © 2024 Apple Inc.
-
-constexpr std::string_view scan_kernels = R"(
-template [[host_name("contig_{0}")]] [[kernel]] void
-contiguous_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
-    const device {1}* in [[buffer(0)]],
-    device {2}* out [[buffer(1)]],
-    const constant size_t& axis_size [[buffer(2)]],
-    uint gid [[thread_position_in_grid]],
-    uint lid [[thread_position_in_threadgroup]],
-    uint lsize [[threads_per_threadgroup]],
-    uint simd_size [[threads_per_simdgroup]],
-    uint simd_lane_id [[thread_index_in_simdgroup]],
-    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
-
-template [[host_name("strided_{0}")]] [[kernel]] void
-strided_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
-    const device {1}* in [[buffer(0)]],
-    device {2}* out [[buffer(1)]],
-    const constant size_t& axis_size [[buffer(2)]],
-    const constant size_t& stride [[buffer(3)]],
-    uint2 gid [[thread_position_in_grid]],
-    uint2 lid [[thread_position_in_threadgroup]],
-    uint2 lsize [[threads_per_threadgroup]],
-    uint simd_size [[threads_per_simdgroup]]);
-)";

diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp
index ff567f0c0..ed1b4b1fd 100644
--- a/mlx/backend/metal/jit_kernels.cpp
+++ b/mlx/backend/metal/jit_kernels.cpp
@@ -4,7 +4,6 @@
 #include "mlx/backend/metal/jit/arange.h"
 #include "mlx/backend/metal/jit/gemv_masked.h"
 #include "mlx/backend/metal/jit/includes.h"
-#include "mlx/backend/metal/jit/scan.h"
 #include "mlx/backend/metal/jit/softmax.h"
 #include "mlx/backend/metal/jit/steel_conv.h"
 #include "mlx/backend/metal/jit/steel_gemm.h"
@@ -224,18 +223,26 @@ MTL::ComputePipelineState* get_scan_kernel(
     const array& out) {
   std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name, [&]() {
-    std::string op_name = "Cum" + reduce_type;
-    op_name[3] = toupper(op_name[3]);
+    auto out_type = get_type_string(out.dtype());
+    std::string op = "Cum" + reduce_type + "<" + out_type + ">";
+    op[3] = toupper(op[3]);
     std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::scan()
-                  << fmt::format(
-                         scan_kernels,
-                         lib_name,
-                         get_type_string(in.dtype()),
-                         get_type_string(out.dtype()),
-                         op_name,
-                         inclusive,
-                         reverse);
+    kernel_source << metal::utils() << metal::scan();
+    const std::array<std::pair<std::string, std::string>, 2> scan_kernels = {{
+        {"contig_", "contiguous_scan"},
+        {"strided_", "strided_scan"},
+    }};
+    for (auto& [prefix, kernel] : scan_kernels) {
+      kernel_source << get_template_definition(
+          prefix + lib_name,
+          kernel,
+          get_type_string(in.dtype()),
+          get_type_string(out.dtype()),
+          op,
+          in.itemsize() <= 4 ? 4 : 2,
+          inclusive,
+          reverse);
+    }
     return kernel_source.str();
   });
   return d.get_kernel(kernel_name, lib);
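Note on the JIT path above: instead of the removed `scan_kernels` format string, each kernel is now instantiated through `get_template_definition`, with the reads-per-thread argument picked from the input itemsize (4 for types of at most 4 bytes, 2 for 8-byte types). As a rough illustration only (the host name and exact layout below are assumptions, not copied from the generated source), the emitted instantiation for a 64-bit inclusive cumsum would look something like:

    // Hypothetical generated source for an int64 inclusive cumsum:
    // host name = "contig_" + library name, N_READS = 2 because the itemsize is 8.
    template [[host_name("contig_scan_inclusive_sum_int64_int64")]] [[kernel]] decltype(
        contiguous_scan<int64_t, int64_t, CumSum<int64_t>, 2, true, false>)
        contiguous_scan<int64_t, int64_t, CumSum<int64_t>, 2, true, false>;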
diff --git a/mlx/backend/metal/kernels/scan.h b/mlx/backend/metal/kernels/scan.h
index 67b27ba89..cfa84c04c 100644
--- a/mlx/backend/metal/kernels/scan.h
+++ b/mlx/backend/metal/kernels/scan.h
@@ -1,7 +1,38 @@
 // Copyright © 2023-2024 Apple Inc.
 
+#pragma once
+
+#define DEFINE_SIMD_SCAN()                                                  \
+  template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true>    \
+  T simd_scan(T val) {                                                      \
+    return simd_scan_impl(val);                                             \
+  }                                                                         \
+                                                                            \
+  template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true>   \
+  T simd_scan(T val) {                                                      \
+    for (int i = 1; i <= 16; i *= 2) {                                      \
+      val = operator()(val, simd_shuffle_and_fill_up(val, init, i));        \
+    }                                                                       \
+    return val;                                                             \
+  }
+
+#define DEFINE_SIMD_EXCLUSIVE_SCAN()                                        \
+  template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true>    \
+  T simd_exclusive_scan(T val) {                                            \
+    return simd_exclusive_scan_impl(val);                                   \
+  }                                                                         \
+                                                                            \
+  template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true>   \
+  T simd_exclusive_scan(T val) {                                            \
+    val = simd_scan(val);                                                   \
+    return simd_shuffle_and_fill_up(val, init, 1);                          \
+  }
+
 template <typename U>
 struct CumSum {
+  DEFINE_SIMD_SCAN()
+  DEFINE_SIMD_EXCLUSIVE_SCAN()
+
   static constexpr constant U init = static_cast<U>(0);
 
   template <typename T>
@@ -9,17 +40,20 @@ struct CumSum {
     return a + b;
   }
 
-  U simd_scan(U x) {
+  U simd_scan_impl(U x) {
     return simd_prefix_inclusive_sum(x);
   }
 
-  U simd_exclusive_scan(U x) {
+  U simd_exclusive_scan_impl(U x) {
     return simd_prefix_exclusive_sum(x);
   }
 };
 
 template <typename U>
 struct CumProd {
+  DEFINE_SIMD_SCAN()
+  DEFINE_SIMD_EXCLUSIVE_SCAN()
+
   static constexpr constant U init = static_cast<U>(1.0f);
 
   template <typename T>
@@ -27,11 +61,11 @@ struct CumProd {
     return a * b;
   }
 
-  U simd_scan(U x) {
+  U simd_scan_impl(U x) {
     return simd_prefix_inclusive_product(x);
   }
 
-  U simd_exclusive_scan(U x) {
+  U simd_exclusive_scan_impl(U x) {
     return simd_prefix_exclusive_product(x);
  }
 };
@@ -47,7 +81,7 @@ struct CumProd<bool> {
 
   bool simd_scan(bool x) {
     for (int i = 1; i <= 16; i *= 2) {
-      bool other = simd_shuffle_up(x, i);
+      bool other = simd_shuffle_and_fill_up(x, init, i);
       x &= other;
     }
     return x;
@@ -70,7 +104,7 @@ struct CumMax {
 
   U simd_scan(U x) {
     for (int i = 1; i <= 16; i *= 2) {
-      U other = simd_shuffle_up(x, i);
+      U other = simd_shuffle_and_fill_up(x, init, i);
       x = (x >= other) ? x : other;
     }
     return x;
@@ -93,7 +127,7 @@ struct CumMin {
   U simd_scan(U x) {
     for (int i = 1; i <= 16; i *= 2) {
-      U other = simd_shuffle_up(x, i);
+      U other = simd_shuffle_and_fill_up(x, init, i);
       x = (x <= other) ? x : other;
     }
     return x;
   }
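The DEFINE_SIMD_SCAN / DEFINE_SIMD_EXCLUSIVE_SCAN macros above route 8-byte types away from the hardware `simd_prefix_inclusive_*` intrinsics and into a shuffle-based Hillis-Steele scan: in each of the five rounds, a lane combines its value with the value 1, 2, 4, 8, then 16 lanes below it, with shifted-out lanes filled by the operator's identity. A minimal CPU sketch of the same idea (plain C++, not the Metal code; the `shifted` array models `simd_shuffle_and_fill_up`):

    #include <array>

    // Inclusive scan over one 32-lane "simdgroup", mirroring the macro's loop.
    template <typename T, typename Op>
    std::array<T, 32> simd_inclusive_scan(std::array<T, 32> v, Op op, T init) {
      for (int delta = 1; delta <= 16; delta *= 2) {
        std::array<T, 32> shifted;
        for (int lane = 0; lane < 32; ++lane) {
          // Lane i reads lane i - delta, or the identity when it shifts off the end.
          shifted[lane] = (lane >= delta) ? v[lane - delta] : init;
        }
        for (int lane = 0; lane < 32; ++lane) {
          v[lane] = op(v[lane], shifted[lane]);
        }
      }
      return v; // v[i] now holds the fold of lanes 0..i
    }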
@@ -178,20 +212,22 @@ template <
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
     const constant size_t& axis_size [[buffer(2)]],
-    uint gid [[thread_position_in_grid]],
-    uint lid [[thread_position_in_threadgroup]],
-    uint lsize [[threads_per_threadgroup]],
-    uint simd_size [[threads_per_simdgroup]],
+    uint3 gid [[threadgroup_position_in_grid]],
+    uint3 gsize [[threadgroups_per_grid]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint3 lsize [[threads_per_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  constexpr int simd_size = 32;
   Op op;
 
   // Position the pointers
-  in += (gid / lsize) * axis_size;
-  out += (gid / lsize) * axis_size;
+  size_t offset = (gid.y + gsize.y * size_t(gid.z)) * axis_size;
+  in += offset;
+  out += offset;
 
   // Compute the number of simd_groups
-  uint simd_groups = lsize / simd_size;
+  uint simd_groups = lsize.x / simd_size;
 
   // Allocate memory
   U prefix = Op::init;
@@ -210,9 +246,9 @@ template <
   //    value
   //    Write block
-  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
+  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) {
     // Compute the block offset
-    uint offset = r * lsize * N_READS + lid * N_READS;
+    uint offset = r * lsize.x * N_READS + lid.x * N_READS;
 
     // Read the values
     if (reverse) {
@@ -275,7 +311,7 @@ template <
             values, out + axis_size - offset - N_READS, offset, axis_size);
       }
     } else {
-      if (lid == 0 && offset == 0) {
+      if (lid.x == 0 && offset == 0) {
         out[axis_size - 1] = Op::init;
       }
       if ((offset + N_READS + 1) < axis_size) {
@@ -298,7 +334,7 @@ template <
             values, out + offset, offset, axis_size);
       }
     } else {
-      if (lid == 0 && offset == 0) {
+      if (lid.x == 0 && offset == 0) {
         out[0] = Op::init;
       }
       if ((offset + N_READS + 1) < axis_size) {
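In the contiguous kernel above, the row being scanned is no longer derived from a flat `thread_position_in_grid` (each grid dimension is limited to 32 bits); it comes from the (y, z) threadgroup coordinates and is widened to `size_t` before multiplying by `axis_size`. A small sketch of that index computation, with hypothetical parameter names, for illustration:

    #include <cstddef>
    #include <cstdint>

    // One threadgroup scans one row of length axis_size; rows are spread over the
    // grid's y and z dimensions so more than 2^32 elements can be addressed.
    size_t row_offset(uint32_t gid_y, uint32_t gid_z, uint32_t gsize_y, size_t axis_size) {
      size_t row = gid_y + static_cast<size_t>(gsize_y) * gid_z; // flatten (y, z)
      return row * axis_size;                                    // 64-bit element offset
    }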
@@ -332,86 +368,98 @@ template <
     device U* out [[buffer(1)]],
     const constant size_t& axis_size [[buffer(2)]],
     const constant size_t& stride [[buffer(3)]],
-    uint2 gid [[threadgroup_position_in_grid]],
-    uint2 lid [[thread_position_in_threadgroup]],
-    uint2 lsize [[threads_per_threadgroup]],
-    uint simd_size [[threads_per_simdgroup]]) {
+    const constant size_t& stride_blocks [[buffer(4)]],
+    uint3 gid [[threadgroup_position_in_grid]],
+    uint3 gsize [[threadgroups_per_grid]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint simd_lane_id [[thread_index_in_simdgroup]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  constexpr int simd_size = 32;
+  constexpr int BM = 32;
+  constexpr int BN = 32;
+  constexpr int BN_pad = 32 + 16 / sizeof(U);
+  constexpr int n_simds = BN / N_READS;
+  constexpr int n_scans = BN / n_simds;
   Op op;
 
-  // Allocate memory
-  threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32];
-  U values[N_READS];
-  U prefix[N_READS];
-  for (int i = 0; i < N_READS; i++) {
+  threadgroup U read_buffer[BM * BN_pad];
+  U values[n_scans];
+  U prefix[n_scans];
+  for (int i = 0; i < n_scans; i++) {
     prefix[i] = Op::init;
   }
 
   // Compute offsets
-  int offset = gid.y * axis_size * stride;
-  int global_index_x = gid.x * lsize.y * N_READS;
+  size_t full_gid = gid.y + gsize.y * size_t(gid.z);
+  size_t offset = full_gid / stride_blocks * axis_size * stride;
+  size_t global_index_x = full_gid % stride_blocks * BN;
+  uint read_offset_y = (lid.x * N_READS) / BN;
+  uint read_offset_x = (lid.x * N_READS) % BN;
+  uint scan_offset_y = simd_lane_id;
+  uint scan_offset_x = simd_group_id * n_scans;
 
-  for (uint j = 0; j < axis_size; j += simd_size) {
+  uint stride_limit = stride - global_index_x;
+  in += offset + global_index_x + read_offset_x;
+  out += offset + global_index_x + read_offset_x;
+  threadgroup U* read_into =
+      read_buffer + read_offset_y * BN_pad + read_offset_x;
+  threadgroup U* read_from =
+      read_buffer + scan_offset_y * BN_pad + scan_offset_x;
+
+  for (uint j = 0; j < axis_size; j += BM) {
     // Calculate the indices for the current thread
-    uint index_y = j + lid.y;
+    uint index_y = j + read_offset_y;
     uint check_index_y = index_y;
-    uint index_x = global_index_x + lid.x * N_READS;
     if (reverse) {
       index_y = axis_size - 1 - index_y;
     }
 
     // Read in SM
-    if (check_index_y < axis_size && (index_x + N_READS) < stride) {
+    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
       for (int i = 0; i < N_READS; i++) {
-        read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
-            in[offset + index_y * stride + index_x + i];
+        read_into[i] = in[index_y * stride + i];
      }
     } else {
       for (int i = 0; i < N_READS; i++) {
-        if (check_index_y < axis_size && (index_x + i) < stride) {
-          read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
-              in[offset + index_y * stride + index_x + i];
+        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
+          read_into[i] = in[index_y * stride + i];
         } else {
-          read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
-              Op::init;
+          read_into[i] = Op::init;
         }
       }
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     // Read strided into registers
-    for (int i = 0; i < N_READS; i++) {
-      values[i] =
-          read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
+    for (int i = 0; i < n_scans; i++) {
+      values[i] = read_from[i];
     }
-    // Do we need the following barrier? Shouldn't all simd threads execute
-    // simultaneously?
     simdgroup_barrier(mem_flags::mem_threadgroup);
 
     // Perform the scan
-    for (int i = 0; i < N_READS; i++) {
+    for (int i = 0; i < n_scans; i++) {
       values[i] = op.simd_scan(values[i]);
       values[i] = op(values[i], prefix[i]);
       prefix[i] = simd_shuffle(values[i], simd_size - 1);
     }
 
     // Write to SM
-    for (int i = 0; i < N_READS; i++) {
-      read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] =
-          values[i];
+    for (int i = 0; i < n_scans; i++) {
+      read_from[i] = values[i];
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     // Write to device memory
     if (!inclusive) {
       if (check_index_y == 0) {
-        if ((index_x + N_READS) < stride) {
+        if ((read_offset_x + N_READS) < stride_limit) {
           for (int i = 0; i < N_READS; i++) {
-            out[offset + index_y * stride + index_x + i] = Op::init;
+            out[index_y * stride + i] = Op::init;
           }
         } else {
           for (int i = 0; i < N_READS; i++) {
-            if ((index_x + i) < stride) {
-              out[offset + index_y * stride + index_x + i] = Op::init;
+            if ((read_offset_x + i) < stride_limit) {
+              out[index_y * stride + i] = Op::init;
             }
           }
         }
@@ -424,16 +472,14 @@ template <
         check_index_y += 1;
       }
     }
-    if (check_index_y < axis_size && (index_x + N_READS) < stride) {
+    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
       for (int i = 0; i < N_READS; i++) {
-        out[offset + index_y * stride + index_x + i] =
-            read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
+        out[index_y * stride + i] = read_into[i];
       }
     } else {
      for (int i = 0; i < N_READS; i++) {
-        if (check_index_y < axis_size && (index_x + i) < stride) {
-          out[offset + index_y * stride + index_x + i] =
-              read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
+        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
+          out[index_y * stride + i] = read_into[i];
         }
       }
     }
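The strided kernel above now works on fixed BM x BN = 32 x 32 tiles: a tile is loaded row by row into threadgroup memory, each simdgroup then scans `n_scans` columns down the tile, and each tile row is padded by `16 / sizeof(U)` extra elements, presumably so the column-wise reads are staggered rather than hitting the same threadgroup-memory bank. A sketch of the padded-tile indexing (constants mirror the kernel; the helpers are illustrative, not the kernel itself):

    #include <cstddef>

    constexpr int BM = 32; // tile rows, along the scan axis
    constexpr int BN = 32; // tile columns, along the stride

    // Elements of padding appended to each tile row (16 bytes worth).
    constexpr int row_pad(size_t item_size) {
      return static_cast<int>(16 / item_size);
    }

    // Flattened offset of element (row, col) in the padded threadgroup tile.
    constexpr size_t tile_index(int row, int col, size_t item_size) {
      return static_cast<size_t>(row) * (BN + row_pad(item_size)) + col;
    }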
diff --git a/mlx/backend/metal/kernels/scan.metal b/mlx/backend/metal/kernels/scan.metal
index 2f026bc36..6aa36f5a3 100644
--- a/mlx/backend/metal/kernels/scan.metal
+++ b/mlx/backend/metal/kernels/scan.metal
@@ -13,15 +13,15 @@ using namespace metal;
 
 #define instantiate_contiguous_scan(                                      \
     name, itype, otype, op, inclusive, reverse, nreads)                   \
-  template [[host_name("contig_scan_" #name)]] [[kernel]] void            \
+  template [[host_name("contig_scan_" #name)]] [[kernel]] void            \
   contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>(   \
       const device itype* in [[buffer(0)]],                               \
       device otype* out [[buffer(1)]],                                    \
       const constant size_t& axis_size [[buffer(2)]],                     \
-      uint gid [[thread_position_in_grid]],                               \
-      uint lid [[thread_position_in_threadgroup]],                        \
-      uint lsize [[threads_per_threadgroup]],                             \
-      uint simd_size [[threads_per_simdgroup]],                           \
+      uint3 gid [[threadgroup_position_in_grid]],                         \
+      uint3 gsize [[threadgroups_per_grid]],                              \
+      uint3 lid [[thread_position_in_threadgroup]],                       \
+      uint3 lsize [[threads_per_threadgroup]],                            \
       uint simd_lane_id [[thread_index_in_simdgroup]],                    \
       uint simd_group_id [[simdgroup_index_in_threadgroup]]);
 
@@ -33,10 +33,12 @@ using namespace metal;
       device otype* out [[buffer(1)]],                                    \
       const constant size_t& axis_size [[buffer(2)]],                     \
       const constant size_t& stride [[buffer(3)]],                        \
-      uint2 gid [[thread_position_in_grid]],                              \
-      uint2 lid [[thread_position_in_threadgroup]],                       \
-      uint2 lsize [[threads_per_threadgroup]],                            \
-      uint simd_size [[threads_per_simdgroup]]);
+      const constant size_t& stride_blocks [[buffer(4)]],                 \
+      uint3 gid [[threadgroup_position_in_grid]],                         \
+      uint3 gsize [[threadgroups_per_grid]],                              \
+      uint3 lid [[thread_position_in_threadgroup]],                       \
+      uint simd_lane_id [[thread_index_in_simdgroup]],                    \
+      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
 
 #define instantiate_scan_helper(name, itype, otype, op, nreads)                            \
   instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads)     \
@@ -52,51 +54,51 @@ instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSu
 instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
 instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
 instantiate_scan_helper(sum_uint32_uint32, uint32_t, uint32_t, CumSum, 4)
-//instantiate_scan_helper(sum_uint64_uint64, uint64_t, uint64_t, CumSum, 2)
+instantiate_scan_helper(sum_uint64_uint64, uint64_t, uint64_t, CumSum, 2)
 instantiate_scan_helper(sum_int8_int8, int8_t, int8_t, CumSum, 4)
 instantiate_scan_helper(sum_int16_int16, int16_t, int16_t, CumSum, 4)
 instantiate_scan_helper(sum_int32_int32, int32_t, int32_t, CumSum, 4)
-//instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
+instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
 instantiate_scan_helper(sum_float16_float16, half, half, CumSum, 4)
 instantiate_scan_helper(sum_float32_float32, float, float, CumSum, 4)
 instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
-//instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum)
-//instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
+instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum, 2)
+instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
 instantiate_scan_helper(prod_uint8_uint8, uint8_t, uint8_t, CumProd, 4)
 instantiate_scan_helper(prod_uint16_uint16, uint16_t, uint16_t, CumProd, 4)
 instantiate_scan_helper(prod_uint32_uint32, uint32_t, uint32_t, CumProd, 4)
-//instantiate_scan_helper(prod_uint64_uint64, uint64_t, uint64_t, CumProd, 2)
+instantiate_scan_helper(prod_uint64_uint64, uint64_t, uint64_t, CumProd, 2)
 instantiate_scan_helper(prod_int8_int8, int8_t, int8_t, CumProd, 4)
 instantiate_scan_helper(prod_int16_int16, int16_t, int16_t, CumProd, 4)
 instantiate_scan_helper(prod_int32_int32, int32_t, int32_t, CumProd, 4)
-//instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
+instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
 instantiate_scan_helper(prod_float16_float16, half, half, CumProd, 4)
 instantiate_scan_helper(prod_float32_float32, float, float, CumProd, 4)
 instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
-//instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd)
-//instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
+instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd, 2)
+instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
 instantiate_scan_helper(max_uint8_uint8, uint8_t, uint8_t, CumMax, 4)
 instantiate_scan_helper(max_uint16_uint16, uint16_t, uint16_t, CumMax, 4)
 instantiate_scan_helper(max_uint32_uint32, uint32_t, uint32_t, CumMax, 4)
-//instantiate_scan_helper(max_uint64_uint64, uint64_t, uint64_t, CumMax, 2)
+instantiate_scan_helper(max_uint64_uint64, uint64_t, uint64_t, CumMax, 2)
 instantiate_scan_helper(max_int8_int8, int8_t, int8_t, CumMax, 4)
 instantiate_scan_helper(max_int16_int16, int16_t, int16_t, CumMax, 4)
 instantiate_scan_helper(max_int32_int32, int32_t, int32_t, CumMax, 4)
-//instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
+instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
 instantiate_scan_helper(max_float16_float16, half, half, CumMax, 4)
 instantiate_scan_helper(max_float32_float32, float, float, CumMax, 4)
 instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
-//instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax)
-//instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
+instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax, 2)
+instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
 instantiate_scan_helper(min_uint8_uint8, uint8_t, uint8_t, CumMin, 4)
 instantiate_scan_helper(min_uint16_uint16, uint16_t, uint16_t, CumMin, 4)
 instantiate_scan_helper(min_uint32_uint32, uint32_t, uint32_t, CumMin, 4)
-//instantiate_scan_helper(min_uint64_uint64, uint64_t, uint64_t, CumMin, 2)
+instantiate_scan_helper(min_uint64_uint64, uint64_t, uint64_t, CumMin, 2)
 instantiate_scan_helper(min_int8_int8, int8_t, int8_t, CumMin, 4)
 instantiate_scan_helper(min_int16_int16, int16_t, int16_t, CumMin, 4)
 instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMin, 4)
-//instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
+instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
 instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
 instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
 instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
-//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) // clang-format on
+instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) // clang-format on
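The previously commented-out 64-bit and complex64 instantiations are enabled above with `nreads = 2`, matching the host-side choice in jit_kernels.cpp. A tiny sketch of that rule (the helper name is assumed, for illustration only):

    #include <cstddef>

    // Elements read per thread: 4 for types up to 4 bytes, 2 for 8-byte types,
    // so each thread still moves at most 16 bytes per read.
    constexpr int n_reads_for(size_t itemsize) {
      return itemsize <= 4 ? 4 : 2;
    }
    static_assert(n_reads_for(4) == 4, "32-bit types read 4 elements");
    static_assert(n_reads_for(8) == 2, "64-bit types read 2 elements");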
diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h
index 721c094ca..33bf8fdae 100644
--- a/mlx/backend/metal/kernels/utils.h
+++ b/mlx/backend/metal/kernels/utils.h
@@ -320,3 +320,63 @@ inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) {
   return complex64_t(
       simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta));
 }
+
+inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) {
+  return as_type<uint64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
+}
+
+inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) {
+  return as_type<int64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
+}
+
+inline bool simd_shuffle_up(bool data, uint16_t delta) {
+  return simd_shuffle_up(static_cast<uint32_t>(data), delta);
+}
+
+inline complex64_t simd_shuffle_up(complex64_t data, uint16_t delta) {
+  return complex64_t(
+      simd_shuffle_up(data.real, delta), simd_shuffle_up(data.imag, delta));
+}
+
+inline uint64_t
+simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) {
+  return as_type<uint64_t>(metal::simd_shuffle_and_fill_up(
+      as_type<uint2>(data), as_type<uint2>(filling), delta));
+}
+
+inline int64_t
+simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) {
+  return as_type<int64_t>(metal::simd_shuffle_and_fill_up(
+      as_type<uint2>(data), as_type<uint2>(filling), delta));
+}
+
+inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) {
+  return simd_shuffle_and_fill_up(
+      static_cast<uint32_t>(data), static_cast<uint32_t>(filling), delta);
+}
+
+inline complex64_t simd_shuffle_and_fill_up(
+    complex64_t data,
+    complex64_t filling,
+    uint16_t delta) {
+  return complex64_t(
+      simd_shuffle_and_fill_up(data.real, filling.real, delta),
+      simd_shuffle_and_fill_up(data.imag, filling.imag, delta));
+}
+
+inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) {
+  return as_type<uint64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
+}
+
+inline int64_t simd_shuffle(int64_t data, uint16_t lane) {
+  return as_type<int64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
+}
+
+inline bool simd_shuffle(bool data, uint16_t lane) {
+  return simd_shuffle(static_cast<uint32_t>(data), lane);
+}
+
+inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {
+  return complex64_t(
+      simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane));
+}
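The overloads above exist because Metal's simd shuffle intrinsics do not take 64-bit integer, bool, or complex operands directly: a 64-bit value is reinterpreted as a `uint2`, the two 32-bit halves are shuffled together, and the result is reinterpreted back (bool goes through a 32-bit integer, and complex64 shuffles its real and imaginary parts separately). A CPU analogue of the reinterpretation step, using memcpy where the kernels use `as_type<>`:

    #include <cstdint>
    #include <cstring>

    struct Halves { uint32_t lo, hi; };

    // Split a 64-bit value into two 32-bit lanes, bit for bit.
    Halves split_u64(uint64_t v) {
      Halves h;
      std::memcpy(&h, &v, sizeof(v));
      return h;
    }

    // Reassemble after both halves have been shuffled with the same lane offset.
    uint64_t join_u64(Halves h) {
      uint64_t v;
      std::memcpy(&v, &h, sizeof(v));
      return v;
    }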
diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp
index 8737b25e4..ae9c6a66f 100644
--- a/mlx/backend/metal/scan.cpp
+++ b/mlx/backend/metal/scan.cpp
@@ -14,19 +14,27 @@ namespace mlx::core {
 
 void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
   auto& s = stream();
   auto& d = metal::device(s.device);
 
-  // Ensure contiguity
   std::vector<array> copies;
   auto in = inputs[0];
-  if (!in.flags().row_contiguous) {
+  if (in.flags().contiguous && in.strides()[axis_] != 0) {
+    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+      out.move_shared_buffer(in);
+    } else {
+      out.set_data(
+          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+          in.data_size(),
+          in.strides(),
+          in.flags());
+    }
+  } else {
     array arr_copy(in.shape(), in.dtype(), nullptr, {});
     copy_gpu(in, arr_copy, CopyType::General, s);
     copies.push_back(arr_copy);
     in = arr_copy;
+    out.move_shared_buffer(in);
   }
 
   bool contiguous = in.strides()[axis_] == 1;
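The rewritten setup above avoids the unconditional `malloc_or_wait(out.nbytes())`: when the input is contiguous it either donates its buffer to the output (`move_shared_buffer`) or the output is allocated with the input's layout, and a non-contiguous input is first copied and the copy's buffer donated. A simplified host-side sketch of the decision (the `ArrayInfo` struct and helper are hypothetical, only to show the branch structure):

    #include <cstddef>

    struct ArrayInfo {
      bool contiguous;  // flags().contiguous && strides()[axis] != 0
      bool donatable;   // no other references hold the buffer
      size_t itemsize;
    };

    // True when the scan may write directly into the input's buffer instead of
    // allocating new storage for the output.
    bool can_donate(const ArrayInfo& in, size_t out_itemsize) {
      return in.contiguous && in.donatable && in.itemsize == out_itemsize;
    }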
@@ -61,7 +69,8 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (contiguous) {
     auto& compute_encoder = d.get_command_encoder(s.index);
     compute_encoder->setComputePipelineState(kernel);
-    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_input_array(
+        in.data_shared_ptr() == nullptr ? out : in, 0);
     compute_encoder.set_output_array(out, 1);
     size_t size = in.shape(axis_);
     compute_encoder->setBytes(&size, sizeof(size_t), 2);
@@ -70,7 +79,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     int n_reads = (in.itemsize() <= 4) ? 4 : 2;
     constexpr int simd_size = 32;
     int elements_per_simd = n_reads * simd_size;
-    int thread_groups = in.size() / size;
     int thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (size <= n_reads * 1024) {
       thread_group_size =
@@ -82,28 +90,41 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     thread_group_size = std::min(
         thread_group_size,
         static_cast<int>(kernel->maxTotalThreadsPerThreadgroup()));
-    MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1);
-    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
+    auto tmp_grid_dims =
+        get_2d_grid_dims(in.shape(), in.strides(), /** divisor= */ size);
+    MTL::Size grid_dims(
+        thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
+    MTL::Size group_dims(thread_group_size, 1, 1);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     auto& compute_encoder = d.get_command_encoder(s.index);
     compute_encoder->setComputePipelineState(kernel);
-    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_input_array(
+        in.data_shared_ptr() == nullptr ? out : in, 0);
     compute_encoder.set_output_array(out, 1);
     size_t size = in.shape(axis_);
     size_t stride = in.strides()[axis_];
+    int bm = 32;
+    int bn = 32;
+    size_t stride_blocks = (stride + bn - 1) / bn;
     compute_encoder->setBytes(&size, sizeof(size_t), 2);
     compute_encoder->setBytes(&stride, sizeof(size_t), 3);
+    compute_encoder->setBytes(&stride_blocks, sizeof(size_t), 4);
 
     // Compute the thread grid
     int n_reads = (in.itemsize() <= 4) ? 4 : 2;
-    int tile_x = 32;
-    int tile_y = 32;
-    int elements_per_tile_x = tile_x * n_reads;
-    int grid_y = in.size() / size / stride;
-    int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x;
-    MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1);
-    MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1);
+    int n_simdgroups = bn / n_reads;
+    int thread_group_size = n_simdgroups * 32;
+    auto tmp_grid_dims = get_2d_grid_dims(
+        in.shape(), in.strides(), /** divisor= */ size * stride);
+    if (tmp_grid_dims.width * stride_blocks <= UINT_MAX) {
+      tmp_grid_dims.width *= stride_blocks;
+    } else {
+      tmp_grid_dims.height *= stride_blocks;
+    }
+    MTL::Size grid_dims(
+        thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
+    MTL::Size group_dims(thread_group_size, 1, 1);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
diff --git a/mlx/backend/metal/utils.cpp b/mlx/backend/metal/utils.cpp
index 1242209a1..d15e221dd 100644
--- a/mlx/backend/metal/utils.cpp
+++ b/mlx/backend/metal/utils.cpp
@@ -107,6 +107,48 @@ MTL::Size get_2d_grid_dims(
       static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
 }
 
+MTL::Size get_2d_grid_dims(
+    const std::vector<int>& shape,
+    const std::vector<size_t>& strides,
+    size_t divisor) {
+  // Compute the 2d grid dimensions such that the total size of the grid is
+  // divided by divisor.
+  size_t grid_x = 1;
+  size_t grid_y = 1;
+  for (int i = 0; i < shape.size(); ++i) {
+    if (strides[i] == 0) {
+      continue;
+    }
+
+    // No need to add this shape we can just remove it from the divisor.
+    if (divisor % shape[i] == 0) {
+      divisor /= shape[i];
+      continue;
+    }
+
+    if (grid_x * shape[i] < UINT32_MAX) {
+      grid_x *= shape[i];
+    } else {
+      grid_y *= shape[i];
+    }
+
+    if (divisor > 1) {
+      if (grid_x % divisor == 0) {
+        grid_x /= divisor;
+        divisor = 1;
+      } else if (grid_y % divisor == 0) {
+        grid_y /= divisor;
+        divisor = 1;
+      }
+    }
+  }
+  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
+    throw std::runtime_error("Unable to safely factor shape.");
+  }
+  return MTL::Size(
+      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
+}
+
 std::string get_primitive_string(Primitive* primitive) {
   std::ostringstream op_t;
   primitive->print(op_t);
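`get_2d_grid_dims(shape, strides, divisor)` above factors the array's element count, skipping broadcast (stride-0) dimensions and dividing out `divisor`, into a (width, height) pair that fits 32-bit grid dimensions; the scan dispatch uses `divisor = axis_size` (contiguous) or `axis_size * stride` (strided) so the grid enumerates scanned rows rather than individual elements. A usage-style sketch with made-up sizes:

    #include <cstddef>
    #include <vector>

    // Number of independent rows a [6, 1000, 100000] cumsum over the last axis
    // must cover: 6 * 1000 = 6000. Passing divisor = 100000 lets the helper drop
    // that factor instead of overflowing a 32-bit grid dimension.
    size_t rows_to_cover(const std::vector<int>& shape, size_t axis_size) {
      size_t total = 1;
      for (int s : shape) {
        total *= static_cast<size_t>(s);
      }
      return total / axis_size;
    }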
diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h
index ad49c52a1..509a7e651 100644
--- a/mlx/backend/metal/utils.h
+++ b/mlx/backend/metal/utils.h
@@ -42,6 +42,14 @@ MTL::Size get_2d_grid_dims(
     const std::vector<int>& shape,
     const std::vector<size_t>& strides);
 
+// Same as above but we do an implicit division with divisor.
+// Basically, equivalent to factorizing
+// Prod(s \forall s in shape if strides[s] > 0) / divisor.
+MTL::Size get_2d_grid_dims(
+    const std::vector<int>& shape,
+    const std::vector<size_t>& strides,
+    size_t divisor);
+
 inline NS::String* make_string(std::ostringstream& os) {
   std::string string = os.str();
   return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 34b2d66bf..0bb34cd87 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1758,6 +1758,18 @@ class TestOps(mlx_tests.MLXTestCase):
                 c_mlx = mxop(a_mlx, axis=axis)
                 self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
 
+        a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)
+        for dt in [mx.int32, mx.int64]:
+            mxx = a_mlx.astype(dt)
+            npx = np.array(mxx)
+            for op in ["cumsum", "cumprod"]:
+                npop = getattr(np, op)
+                mxop = getattr(mx, op)
+                for axis in (None, 0, 1, 2):
+                    c_npy = npop(npx, axis=axis, dtype=npx.dtype)
+                    c_mlx = mxop(mxx, axis=axis)
+                    self.assertTrue(np.array_equal(c_npy, c_mlx))
+
         a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)
         for op in ["cumsum", "cumprod", "cummax", "cummin"]:
             mxop = getattr(mx, op)