diff --git a/docs/src/install.rst b/docs/src/install.rst index 7b99c3145..bdad740ea 100644 --- a/docs/src/install.rst +++ b/docs/src/install.rst @@ -200,7 +200,7 @@ GGUF, you can do: -DBUILD_SHARED_LIBS=ON \ -DMLX_BUILD_CPU=OFF \ -DMLX_BUILD_SAFETENSORS=OFF \ - -DMLX_BUILD_GGUF=OFF + -DMLX_BUILD_GGUF=OFF \ -DMLX_METAL_JIT=ON THE `MLX_METAL_JIT` flag minimizes the size of the MLX Metal library which diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 46dceb788..6d46b03c7 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -35,6 +35,7 @@ make_jit_source( utils kernels/bf16.h kernels/complex.h + kernels/defines.h ) make_jit_source( unary_ops @@ -44,7 +45,7 @@ make_jit_source( make_jit_source(binary_ops) make_jit_source(ternary_ops) make_jit_source( - reduction + reduce_utils kernels/atomic.h kernels/reduction/ops.h ) @@ -57,11 +58,21 @@ if (MLX_METAL_JIT) PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp ) + make_jit_source(arange) make_jit_source(copy) make_jit_source(unary) make_jit_source(binary) make_jit_source(binary_two) make_jit_source(ternary) + make_jit_source(softmax) + make_jit_source(scan) + make_jit_source(sort) + make_jit_source( + reduce + kernels/reduction/reduce_all.h + kernels/reduction/reduce_col.h + kernels/reduction/reduce_row.h + ) else() target_sources( mlx diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index 0fec4eb70..b1151d9a4 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -229,7 +229,8 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { auto lib = d.get_library(lib_name); if (lib == nullptr) { std::ostringstream kernel_source; - kernel_source << metal::utils() << metal::reduction() << metal::scatter(); + kernel_source << metal::utils() << metal::reduce_utils() + << metal::scatter(); std::string out_type_str = get_type_string(out.dtype()); std::string idx_type_str = diff --git a/mlx/backend/metal/jit/arange.h b/mlx/backend/metal/jit/arange.h new file mode 100644 index 000000000..0c224dca4 --- /dev/null +++ b/mlx/backend/metal/jit/arange.h @@ -0,0 +1,9 @@ +// Copyright © 2024 Apple Inc. + +constexpr std::string_view arange_kernels = R"( +template [[host_name("{0}")]] [[kernel]] void arange<{1}>( + constant const {1}& start, + constant const {1}& step, + device {1}* out, + uint index [[thread_position_in_grid]]); +)"; diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index e98462374..7bfcbbedd 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -1,4 +1,4 @@ -// Copyright © 2023-24 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #pragma once @@ -8,14 +8,19 @@ const char* utils(); const char* binary_ops(); const char* unary_ops(); const char* ternary_ops(); -const char* reduction(); +const char* reduce_utils(); const char* gather(); const char* scatter(); +const char* arange(); const char* unary(); const char* binary(); const char* binary_two(); const char* copy(); const char* ternary(); +const char* scan(); +const char* softmax(); +const char* sort(); +const char* reduce(); } // namespace mlx::core::metal diff --git a/mlx/backend/metal/jit/reduce.h b/mlx/backend/metal/jit/reduce.h new file mode 100644 index 000000000..ea2f5f4e0 --- /dev/null +++ b/mlx/backend/metal/jit/reduce.h @@ -0,0 +1,168 @@ +// Copyright © 2024 Apple Inc. 
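+
+// fmt-style template strings for the JIT build. get_reduce_init_kernel() and
+// get_reduce_kernel() in jit_kernels.cpp format these with the kernel name,
+// the input/output types, and the reduction op, then compile the result into
+// a Metal library that is cached by name.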
+ +constexpr std::string_view reduce_init_kernels = R"( +[[kernel]] void {0}( + device {1}* out [[buffer(0)]], + uint tid [[thread_position_in_grid]]) {{ + out[tid] = {2}<{1}>::init; +}} +)"; + +constexpr std::string_view reduce_kernels = R"( +template [[host_name("all_{0}")]] [[kernel]] void +all_reduce<{1}, {2}, {3}<{2}>>( + const device {1}* in [[buffer(0)]], + device mlx_atomic<{2}>* out [[buffer(1)]], + const device size_t& in_size [[buffer(2)]], + uint gid [[thread_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint grid_size [[threads_per_grid]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]); +template [[host_name("colGeneral_{0}")]] [[kernel]] void +col_reduce_general<{1}, {2}, {3}<{2}>>( + const device {1}* in [[buffer(0)]], + device mlx_atomic<{2}>* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant size_t& out_size [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], + threadgroup {2}* local_data [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]]); +template [[host_name("colSmall_{0}")]] [[kernel]] void +col_reduce_small<{1}, {2}, {3}<{2}>>( + const device {1}* in [[buffer(0)]], + device {2}* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant size_t& out_size [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], + const constant size_t& non_col_reductions [[buffer(8)]], + const constant int* non_col_shapes [[buffer(9)]], + const constant size_t* non_col_strides [[buffer(10)]], + const constant int& non_col_ndim [[buffer(11)]], + uint tid [[thread_position_in_grid]]); +template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void +row_reduce_general_small<{1}, {2}, {3}<{2}>>( + const device {1}* in [[buffer(0)]], + device {2}* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& out_size [[buffer(3)]], + const constant size_t& non_row_reductions [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], + uint lid [[thread_position_in_grid]]); +template [[host_name("rowGeneralMed_{0}")]] [[kernel]] void +row_reduce_general_med<{1}, {2}, {3}<{2}>>( + const device {1}* in [[buffer(0)]], + device {2}* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& out_size [[buffer(3)]], + const constant size_t& non_row_reductions [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], + uint tid [[threadgroup_position_in_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]); +template [[host_name("rowGeneral_{0}")]] [[kernel]] void +row_reduce_general<{1}, {2}, {3}<{2}>>( + const device {1}* in [[buffer(0)]], + device mlx_atomic<{2}>* out [[buffer(1)]], + const constant size_t& 
reduction_size [[buffer(2)]], + const constant size_t& out_size [[buffer(3)]], + const constant size_t& non_row_reductions [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]); +)"; + +constexpr std::string_view reduce_non_atomic_kernels = R"( +template [[host_name("allNoAtomics_{0}")]] [[kernel]] void +all_reduce_no_atomics<{1}, {2}, {3}<{2}>>( + const device {1}* in [[buffer(0)]], + device {2}* out [[buffer(1)]], + const device size_t& in_size [[buffer(2)]], + uint gid [[thread_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint grid_size [[threads_per_grid]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint thread_group_id [[threadgroup_position_in_grid]]); + +template [[host_name("colGeneralNoAtomics_{0}")]] [[kernel]] void + col_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>( + const device {1}* in [[buffer(0)]], + device {2}* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant size_t& out_size [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], + threadgroup {2}* local_data [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 gid [[thread_position_in_grid]], + uint3 lsize [[threads_per_threadgroup]], + uint3 gsize [[threads_per_grid]]); +template [[host_name("colSmall_{0}")]] [[kernel]] void +col_reduce_small<{1}, {2}, {3}<{2}>>( + const device {1}* in [[buffer(0)]], + device {2}* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant size_t& out_size [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], + const constant size_t& non_col_reductions [[buffer(8)]], + const constant int* non_col_shapes [[buffer(9)]], + const constant size_t* non_col_strides [[buffer(10)]], + const constant int& non_col_ndim [[buffer(11)]], + uint tid [[thread_position_in_grid]]); +template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void +row_reduce_general_small<{1}, {2}, {3}<{2}>>( + const device {1}* in [[buffer(0)]], + device {2}* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& out_size [[buffer(3)]], + const constant size_t& non_row_reductions [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], + uint lid [[thread_position_in_grid]]); +template [[host_name("rowGeneralNoAtomics_{0}")]] [[kernel]] void +row_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>( + const device {1}* in [[buffer(0)]], + device {2}* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& out_size [[buffer(3)]], + const constant size_t& non_row_reductions [[buffer(4)]], + 
const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint3 gsize [[threads_per_grid]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]); +)"; diff --git a/mlx/backend/metal/jit/scan.h b/mlx/backend/metal/jit/scan.h new file mode 100644 index 000000000..e3e40cf9b --- /dev/null +++ b/mlx/backend/metal/jit/scan.h @@ -0,0 +1,26 @@ +// Copyright © 2024 Apple Inc. + +constexpr std::string_view scan_kernels = R"( +template [[host_name("contig_{0}")]] [[kernel]] void +contiguous_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>( + const device {1}* in [[buffer(0)]], + device {2}* out [[buffer(1)]], + const constant size_t& axis_size [[buffer(2)]], + uint gid [[thread_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_size [[threads_per_simdgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +template [[host_name("strided_{0}")]] [[kernel]] void +strided_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>( + const device {1}* in [[buffer(0)]], + device {2}* out [[buffer(1)]], + const constant size_t& axis_size [[buffer(2)]], + const constant size_t& stride [[buffer(3)]], + uint2 gid [[thread_position_in_grid]], + uint2 lid [[thread_position_in_threadgroup]], + uint2 lsize [[threads_per_threadgroup]], + uint simd_size [[threads_per_simdgroup]]); +)"; diff --git a/mlx/backend/metal/jit/softmax.h b/mlx/backend/metal/jit/softmax.h new file mode 100644 index 000000000..a9672a050 --- /dev/null +++ b/mlx/backend/metal/jit/softmax.h @@ -0,0 +1,23 @@ +// Copyright © 2024 Apple Inc. + +constexpr std::string_view softmax_kernels = R"( +template [[host_name("block_{0}")]] [[kernel]] void +softmax_single_row<{1}, {2}>( + const device {1}* in, + device {1}* out, + constant int& axis_size, + uint gid [[thread_position_in_grid]], + uint _lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]); +template [[host_name("looped_{0}")]] [[kernel]] void +softmax_looped<{1}, {2}>( + const device {1}* in, + device {1}* out, + constant int& axis_size, + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]); +)"; diff --git a/mlx/backend/metal/jit/sort.h b/mlx/backend/metal/jit/sort.h new file mode 100644 index 000000000..023822f48 --- /dev/null +++ b/mlx/backend/metal/jit/sort.h @@ -0,0 +1,81 @@ +// Copyright © 2024 Apple Inc. 
+ +constexpr std::string_view block_sort_kernels = R"( +template [[host_name("carg_{0}")]] [[kernel]] void +block_sort<{1}, {2}, true, {3}, {4}>( + const device {1}* inp [[buffer(0)]], + device {2}* out [[buffer(1)]], + const constant int& size_sorted_axis [[buffer(2)]], + const constant int& stride_sorted_axis [[buffer(3)]], + const constant int& stride_segment_axis [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]); +template [[host_name("ncarg_{0}")]] [[kernel]] void +block_sort_nc<{1}, {2}, true, {3}, {4}>( + const device {1}* inp [[buffer(0)]], + device {2}* out [[buffer(1)]], + const constant int& size_sorted_axis [[buffer(2)]], + const constant int& stride_sorted_axis [[buffer(3)]], + const constant int& nc_dim [[buffer(4)]], + const device int* nc_shape [[buffer(5)]], + const device size_t* nc_strides [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]); +template [[host_name("c_{0}")]] [[kernel]] void +block_sort<{1}, {2}, false, {3}, {4}>( + const device {1}* inp [[buffer(0)]], + device {2}* out [[buffer(1)]], + const constant int& size_sorted_axis [[buffer(2)]], + const constant int& stride_sorted_axis [[buffer(3)]], + const constant int& stride_segment_axis [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]); +template [[host_name("nc_{0}")]] [[kernel]] void +block_sort_nc<{1}, {2}, false, {3}, {4}>( + const device {1}* inp [[buffer(0)]], + device {2}* out [[buffer(1)]], + const constant int& size_sorted_axis [[buffer(2)]], + const constant int& stride_sorted_axis [[buffer(3)]], + const constant int& nc_dim [[buffer(4)]], + const device int* nc_shape [[buffer(5)]], + const device size_t* nc_strides [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]); +)"; + +constexpr std::string_view multiblock_sort_kernels = R"( +template [[host_name("sort_{0}")]] [[kernel]] void +mb_block_sort<{1}, {2}, true, {3}, {4}>( + const device {1}* inp [[buffer(0)]], + device {1}* out_vals [[buffer(1)]], + device {2}* out_idxs [[buffer(2)]], + const constant int& size_sorted_axis [[buffer(3)]], + const constant int& stride_sorted_axis [[buffer(4)]], + const constant int& nc_dim [[buffer(5)]], + const device int* nc_shape [[buffer(6)]], + const device size_t* nc_strides [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]); +template [[host_name("partition_{0}")]] [[kernel]] void +mb_block_partition<{1}, {2}, true, {3}, {4}>( + device {2}* block_partitions [[buffer(0)]], + const device {1}* dev_vals [[buffer(1)]], + const device {2}* dev_idxs [[buffer(2)]], + const constant int& size_sorted_axis [[buffer(3)]], + const constant int& merge_tiles [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 tgp_dims [[threads_per_threadgroup]]); +template [[host_name("merge_{0}")]] [[kernel]] void +mb_block_merge<{1}, {2}, true, {3}, {4}>( + const device {2}* block_partitions [[buffer(0)]], + const device {1}* dev_vals_in [[buffer(1)]], + const device {2}* dev_idxs_in [[buffer(2)]], + device {1}* dev_vals_out [[buffer(3)]], + device {2}* dev_idxs_out [[buffer(4)]], + const constant int& size_sorted_axis [[buffer(5)]], + const constant int& merge_tiles [[buffer(6)]], + const constant int& num_tiles [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid 
[[thread_position_in_threadgroup]]); +)"; diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 60c9c7d2d..5d46adbc5 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -3,10 +3,15 @@ #include #include "mlx/backend/common/compiled.h" +#include "mlx/backend/metal/jit/arange.h" #include "mlx/backend/metal/jit/binary.h" #include "mlx/backend/metal/jit/binary_two.h" #include "mlx/backend/metal/jit/copy.h" #include "mlx/backend/metal/jit/includes.h" +#include "mlx/backend/metal/jit/reduce.h" +#include "mlx/backend/metal/jit/scan.h" +#include "mlx/backend/metal/jit/softmax.h" +#include "mlx/backend/metal/jit/sort.h" #include "mlx/backend/metal/jit/ternary.h" #include "mlx/backend/metal/jit/unary.h" #include "mlx/backend/metal/kernels.h" @@ -20,6 +25,22 @@ std::string op_name(const array& arr) { return op_t.str(); } +MTL::ComputePipelineState* get_arange_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out) { + const auto& lib_name = kernel_name; + auto lib = d.get_library(lib_name); + if (lib == nullptr) { + std::ostringstream kernel_source; + kernel_source + << metal::utils() << metal::arange() + << fmt::format(arange_kernels, lib_name, get_type_string(out.dtype())); + lib = d.get_library(lib_name, kernel_source.str()); + } + return d.get_kernel(kernel_name, lib); +} + MTL::ComputePipelineState* get_unary_kernel( metal::Device& d, const std::string& kernel_name, @@ -121,4 +142,138 @@ MTL::ComputePipelineState* get_copy_kernel( return d.get_kernel(kernel_name, lib); } +MTL::ComputePipelineState* get_softmax_kernel( + metal::Device& d, + const std::string& kernel_name, + bool precise, + const array& out) { + std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); + auto lib = d.get_library(lib_name); + if (lib == nullptr) { + std::ostringstream kernel_source; + kernel_source << metal::utils() << metal::softmax() + << fmt::format( + softmax_kernels, + lib_name, + get_type_string(out.dtype()), + get_type_string(precise ? 
float32 : out.dtype())); + lib = d.get_library(lib_name, kernel_source.str()); + } + return d.get_kernel(kernel_name, lib); +} + +MTL::ComputePipelineState* get_scan_kernel( + metal::Device& d, + const std::string& kernel_name, + bool reverse, + bool inclusive, + const array& in, + const array& out) { + std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); + auto lib = d.get_library(lib_name); + if (lib == nullptr) { + std::ostringstream kernel_source; + kernel_source << metal::utils() << metal::scan() + << fmt::format( + scan_kernels, + lib_name, + get_type_string(in.dtype()), + get_type_string(out.dtype()), + op_name(out), + inclusive, + reverse); + lib = d.get_library(lib_name, kernel_source.str()); + } + return d.get_kernel(kernel_name, lib); +} + +MTL::ComputePipelineState* get_sort_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out, + int bn, + int tn) { + std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); + auto lib = d.get_library(lib_name); + if (lib == nullptr) { + std::ostringstream kernel_source; + kernel_source << metal::utils() << metal::sort() + << fmt::format( + block_sort_kernels, + lib_name, + get_type_string(in.dtype()), + get_type_string(out.dtype()), + bn, + tn); + lib = d.get_library(lib_name, kernel_source.str()); + } + return d.get_kernel(kernel_name, lib); +} + +MTL::ComputePipelineState* get_mb_sort_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& idx, + int bn, + int tn) { + std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); + auto lib = d.get_library(lib_name); + if (lib == nullptr) { + std::ostringstream kernel_source; + kernel_source << metal::utils() << metal::sort() + << fmt::format( + multiblock_sort_kernels, + lib_name, + get_type_string(in.dtype()), + get_type_string(idx.dtype()), + bn, + tn); + lib = d.get_library(lib_name, kernel_source.str()); + } + return d.get_kernel(kernel_name, lib); +} + +MTL::ComputePipelineState* get_reduce_init_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out) { + auto lib = d.get_library(kernel_name); + if (lib == nullptr) { + std::ostringstream kernel_source; + kernel_source << metal::utils() << metal::reduce_utils() + << fmt::format( + reduce_init_kernels, + kernel_name, + get_type_string(out.dtype()), + op_name(out)); + lib = d.get_library(kernel_name, kernel_source.str()); + } + return d.get_kernel(kernel_name, lib); +} + +MTL::ComputePipelineState* get_reduce_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out) { + std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); + auto lib = d.get_library(lib_name); + if (lib == nullptr) { + bool non_atomic = out.dtype() == int64 || out.dtype() == uint64; + std::ostringstream kernel_source; + kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce() + << fmt::format( + non_atomic ? 
reduce_non_atomic_kernels + : reduce_kernels, + lib_name, + get_type_string(in.dtype()), + get_type_string(out.dtype()), + op_name(out)); + lib = d.get_library(lib_name, kernel_source.str()); + } + return d.get_kernel(kernel_name, lib); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index e7afc04b3..82ce8e5ea 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -5,6 +5,11 @@ namespace mlx::core { +MTL::ComputePipelineState* get_arange_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out); + MTL::ComputePipelineState* get_unary_kernel( metal::Device& d, const std::string& kernel_name, @@ -33,4 +38,45 @@ MTL::ComputePipelineState* get_copy_kernel( const array& in, const array& out); +MTL::ComputePipelineState* get_softmax_kernel( + metal::Device& d, + const std::string& kernel_name, + bool precise, + const array& out); + +MTL::ComputePipelineState* get_scan_kernel( + metal::Device& d, + const std::string& kernel_name, + bool reverse, + bool inclusive, + const array& in, + const array& out); + +MTL::ComputePipelineState* get_sort_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out, + int bn, + int tn); + +MTL::ComputePipelineState* get_mb_sort_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& idx, + int bn, + int tn); + +MTL::ComputePipelineState* get_reduce_init_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out); + +MTL::ComputePipelineState* get_reduce_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out); + } // namespace mlx::core diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 0e29c0650..841c4800d 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -8,9 +8,9 @@ set( ${CMAKE_CURRENT_SOURCE_DIR}/utils.h ) + set( KERNELS - "arange" "arg_reduce" "conv" "fft" @@ -20,31 +20,42 @@ set( "rms_norm" "layer_norm" "rope" - "scan" "scaled_dot_product_attention" - "softmax" - "sort" ) if (NOT MLX_METAL_JIT) set( KERNELS ${KERNELS} + "arange" "binary" "binary_two" "unary" "ternary" "copy" + "softmax" + "sort" + "scan" + "reduce" ) set( HEADERS ${HEADERS} + arange.h unary_ops.h unary.h binary_ops.h binary.h ternary.h copy.h + softmax.h + sort.h + scan.h + reduction/ops.h + reduction/reduce_init.h + reduction/reduce_all.h + reduction/reduce_col.h + reduction/reduce_row.h ) endif() @@ -87,15 +98,6 @@ foreach(KERNEL ${STEEL_KERNELS}) set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR}) endforeach() -file(GLOB_RECURSE REDUCE_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.metal) -file(GLOB_RECURSE REDUCE_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.h) - -foreach(KERNEL ${REDUCE_KERNELS}) - cmake_path(GET KERNEL STEM TARGET) - build_kernel_base(${TARGET} ${KERNEL} "${REDUCE_HEADERS}") - set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR}) -endforeach() - add_custom_command( OUTPUT ${MLX_METAL_PATH}/mlx.metallib COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib diff --git a/mlx/backend/metal/kernels/arange.h b/mlx/backend/metal/kernels/arange.h new file mode 100644 index 000000000..5448fe9aa --- /dev/null +++ b/mlx/backend/metal/kernels/arange.h @@ -0,0 +1,9 @@ +// Copyright © 2023-2024 Apple Inc. 
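+// Fills out[index] with start + index * step, one thread per output element.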
+template +[[kernel]] void arange( + constant const T& start, + constant const T& step, + device T* out, + uint index [[thread_position_in_grid]]) { + out[index] = start + index * step; +} diff --git a/mlx/backend/metal/kernels/arange.metal b/mlx/backend/metal/kernels/arange.metal index c333414f5..ebc81c630 100644 --- a/mlx/backend/metal/kernels/arange.metal +++ b/mlx/backend/metal/kernels/arange.metal @@ -1,15 +1,8 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. +// clang-format off #include "mlx/backend/metal/kernels/bf16.h" - -template -[[kernel]] void arange( - constant const T& start, - constant const T& step, - device T* out, - uint index [[thread_position_in_grid]]) { - out[index] = start + index * step; -} +#include "mlx/backend/metal/kernels/arange.h" #define instantiate_arange(tname, type) \ template [[host_name("arange" #tname)]] [[kernel]] void arange( \ @@ -18,7 +11,6 @@ template device type* out, \ uint index [[thread_position_in_grid]]); -// clang-format off instantiate_arange(uint8, uint8_t) instantiate_arange(uint16, uint16_t) instantiate_arange(uint32, uint32_t) diff --git a/mlx/backend/metal/kernels/atomic.h b/mlx/backend/metal/kernels/atomic.h index 7ee6ac294..5f526de04 100644 --- a/mlx/backend/metal/kernels/atomic.h +++ b/mlx/backend/metal/kernels/atomic.h @@ -2,11 +2,8 @@ #pragma once -#ifndef MLX_METAL_JIT #include #include -#include "mlx/backend/metal/kernels/bf16.h" -#endif using namespace metal; diff --git a/mlx/backend/metal/kernels/defines.h b/mlx/backend/metal/kernels/defines.h index 7f7bb49ed..3c4fcbdeb 100644 --- a/mlx/backend/metal/kernels/defines.h +++ b/mlx/backend/metal/kernels/defines.h @@ -2,7 +2,7 @@ #pragma once -#ifdef __METAL__ +#if defined __METAL__ || defined MLX_METAL_JIT #define MTL_CONST constant #else #define MTL_CONST @@ -11,6 +11,5 @@ static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4; static MTL_CONST constexpr int REDUCE_N_READS = 16; static MTL_CONST constexpr int SOFTMAX_N_READS = 4; -static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096; static MTL_CONST constexpr int RMS_N_READS = 4; static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096; diff --git a/mlx/backend/metal/kernels/reduce.h b/mlx/backend/metal/kernels/reduce.h new file mode 100644 index 000000000..649db0adb --- /dev/null +++ b/mlx/backend/metal/kernels/reduce.h @@ -0,0 +1,4 @@ +#pragma once +#include "mlx/backend/metal/kernels/reduction/reduce_all.h" +#include "mlx/backend/metal/kernels/reduction/reduce_col.h" +#include "mlx/backend/metal/kernels/reduction/reduce_row.h" diff --git a/mlx/backend/metal/kernels/reduce.metal b/mlx/backend/metal/kernels/reduce.metal new file mode 100644 index 000000000..b8d94df06 --- /dev/null +++ b/mlx/backend/metal/kernels/reduce.metal @@ -0,0 +1,293 @@ +// Copyright © 2024 Apple Inc. 
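+
+// Used only when MLX_METAL_JIT is off: pre-build every reduction kernel
+// (init, all-, row- and column-reduce) for the supported type/op pairs.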
+ +#include +#include + +// clang-format off +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/atomic.h" +#include "mlx/backend/metal/kernels/reduction/ops.h" +#include "mlx/backend/metal/kernels/reduction/reduce_init.h" +#include "mlx/backend/metal/kernels/reduce.h" + +#define instantiate_reduce_helper_floats(inst_f, name, op) \ + inst_f(name, float16, half, op) \ + inst_f(name, float32, float, op) \ + inst_f(name, bfloat16, bfloat16_t, op) + +#define instantiate_reduce_helper_uints(inst_f, name, op) \ + inst_f(name, uint8, uint8_t, op) \ + inst_f(name, uint16, uint16_t, op) \ + inst_f(name, uint32, uint32_t, op) + +#define instantiate_reduce_helper_ints(inst_f, name, op) \ + inst_f(name, int8, int8_t, op) \ + inst_f(name, int16, int16_t, op) \ + inst_f(name, int32, int32_t, op) + +#define instantiate_reduce_helper_64b(inst_f, name, op) \ + inst_f(name, int64, int64_t, op) \ + inst_f(name, uint64, uint64_t, op) + +#define instantiate_reduce_helper_types(inst_f, name, op) \ + instantiate_reduce_helper_floats(inst_f, name, op) \ + instantiate_reduce_helper_uints(inst_f, name, op) \ + instantiate_reduce_helper_ints(inst_f, name, op) + +#define instantiate_reduce_ops(inst_f, type_f) \ + type_f(inst_f, sum, Sum) \ + type_f(inst_f, prod, Prod) \ + type_f(inst_f, min_, Min) \ + type_f(inst_f, max_, Max) + +// Special case for bool reductions +#define instantiate_reduce_from_types_helper( \ + inst_f, name, tname, itype, otype, op) \ + inst_f(name##tname, itype, otype, op) + +#define instantiate_reduce_from_types(inst_f, name, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, bool_, bool, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, uint8, uint8_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, uint16, uint16_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, uint32, uint32_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, uint64, uint64_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, int8, int8_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, int16, int16_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, int32, int32_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, int64, int64_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, float16, half, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, \ + name, \ + float32, \ + float, \ + otype, \ + op) \ + instantiate_reduce_from_types_helper( \ + inst_f, \ + name, \ + bfloat16, \ + bfloat16_t, \ + otype, \ + op) + +#define instantiate_init_reduce(name, otype, op) \ + template [[host_name("i_reduce_" #name)]] [[kernel]] void \ + init_reduce( \ + device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]); + +#define instantiate_init_reduce_helper(name, tname, type, op) \ + instantiate_init_reduce(name##tname, type, op) + +instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types) +instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b) + +instantiate_init_reduce(andbool_, bool, And) +instantiate_init_reduce(orbool_, bool, Or) + +#define instantiate_all_reduce(name, itype, otype, op) \ + template [[host_name("all_reduce_" #name)]] [[kernel]] void \ + all_reduce( \ + const device itype* in [[buffer(0)]], \ + device mlx_atomic* out [[buffer(1)]], \ + const 
device size_t& in_size [[buffer(2)]], \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint grid_size [[threads_per_grid]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \ + template [[host_name("allNoAtomics_reduce_" #name)]] [[kernel]] void \ + all_reduce_no_atomics( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const device size_t& in_size [[buffer(2)]], \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint grid_size [[threads_per_grid]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint thread_group_id [[threadgroup_position_in_grid]]); + +#define instantiate_same_all_reduce_helper(name, tname, type, op) \ + instantiate_all_reduce(name##tname, type, type, op) + +#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \ + instantiate_all_reduce_no_atomics(name##tname, type, type, op) + +instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types) +instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b) + +instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And) +instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or) + +// special case bool with larger output type +instantiate_all_reduce(sumbool_, bool, uint32_t, Sum) + +#define instantiate_col_reduce_general(name, itype, otype, op) \ + template [[host_name("colGeneral_reduce_" #name)]] [[kernel]] void \ + col_reduce_general( \ + const device itype* in [[buffer(0)]], \ + device mlx_atomic* out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& reduction_stride [[buffer(3)]], \ + const constant size_t& out_size [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + threadgroup otype* local_data [[threadgroup(0)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]]); + +#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \ + template \ + [[host_name("colGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \ + col_reduce_general_no_atomics( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& reduction_stride [[buffer(3)]], \ + const constant size_t& out_size [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + threadgroup otype* local_data [[threadgroup(0)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 gid [[thread_position_in_grid]], \ + uint3 lsize [[threads_per_threadgroup]], \ + uint3 gsize [[threads_per_grid]]); + +#define instantiate_col_reduce_small(name, itype, otype, op) \ + template [[host_name("colSmall_reduce_" #name)]] [[kernel]] void \ + col_reduce_small( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant 
size_t& reduction_size [[buffer(2)]], \ + const constant size_t& reduction_stride [[buffer(3)]], \ + const constant size_t& out_size [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + const constant size_t& non_col_reductions [[buffer(8)]], \ + const constant int* non_col_shapes [[buffer(9)]], \ + const constant size_t* non_col_strides [[buffer(10)]], \ + const constant int& non_col_ndim [[buffer(11)]], \ + uint tid [[thread_position_in_grid]]); + +#define instantiate_same_col_reduce_helper(name, tname, type, op) \ + instantiate_col_reduce_small(name ##tname, type, type, op) \ + instantiate_col_reduce_general(name ##tname, type, type, op) + +#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \ + instantiate_col_reduce_small(name ##tname, type, type, op) \ + instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op) + +instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types) +instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b) + +instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum) +instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And) +instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or) + +instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum) +instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And) +instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or) + +#define instantiate_row_reduce_small(name, itype, otype, op) \ + template [[host_name("rowGeneralSmall_reduce_" #name)]] [[kernel]] void \ + row_reduce_general_small( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& out_size [[buffer(3)]], \ + const constant size_t& non_row_reductions [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + uint lid [[thread_position_in_grid]]); \ + template [[host_name("rowGeneralMed_reduce_" #name)]] [[kernel]] void \ + row_reduce_general_med( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& out_size [[buffer(3)]], \ + const constant size_t& non_row_reductions [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + uint tid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_row_reduce_general(name, itype, otype, op) \ + instantiate_row_reduce_small(name, itype, otype, op) \ + template \ + [[host_name("rowGeneral_reduce_" #name)]] [[kernel]] void \ + row_reduce_general( \ + const device itype* in [[buffer(0)]], \ + device mlx_atomic* out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& out_size [[buffer(3)]], \ + const constant size_t& non_row_reductions [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + uint3 lid 
[[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \ + instantiate_row_reduce_small(name, itype, otype, op) \ + template \ + [[host_name("rowGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \ + row_reduce_general_no_atomics( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& out_size [[buffer(3)]], \ + const constant size_t& non_row_reductions [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]], \ + uint3 gsize [[threads_per_grid]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_same_row_reduce_helper(name, tname, type, op) \ + instantiate_row_reduce_general(name##tname, type, type, op) + +#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \ + instantiate_row_reduce_general_no_atomics(name##tname, type, type, op) + +instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types) +instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b) + +instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And) +instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or) + +instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum) + // clang-format on diff --git a/mlx/backend/metal/kernels/reduction.h b/mlx/backend/metal/kernels/reduce_utils.h similarity index 100% rename from mlx/backend/metal/kernels/reduction.h rename to mlx/backend/metal/kernels/reduce_utils.h diff --git a/mlx/backend/metal/kernels/reduction/kernels/reduce_init.metal b/mlx/backend/metal/kernels/reduction/kernels/reduce_init.metal deleted file mode 100644 index f0ca1c8b4..000000000 --- a/mlx/backend/metal/kernels/reduction/kernels/reduce_init.metal +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. 
- -#include "mlx/backend/metal/kernels/reduction/ops.h" -#include "mlx/backend/metal/kernels/reduction/reduce_inst.h" -#include "mlx/backend/metal/kernels/reduction/utils.h" - -using namespace metal; - -/////////////////////////////////////////////////////////////////////////////// -// Reduce init -/////////////////////////////////////////////////////////////////////////////// - -template -[[kernel]] void init_reduce( - device T* out [[buffer(0)]], - uint tid [[thread_position_in_grid]]) { - out[tid] = Op::init; -} - -#define instantiate_init_reduce(name, otype, op) \ - template [[host_name("i" #name)]] [[kernel]] void init_reduce( \ - device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]); - -#define instantiate_init_reduce_helper(name, tname, type, op) \ - instantiate_init_reduce(name##tname, type, op) - -// clang-format off -instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types) -instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b) - -instantiate_init_reduce(andbool_, bool, And) -instantiate_init_reduce(orbool_, bool, Or) // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/reduction/ops.h b/mlx/backend/metal/kernels/reduction/ops.h index 5996c2071..dc6d17b49 100644 --- a/mlx/backend/metal/kernels/reduction/ops.h +++ b/mlx/backend/metal/kernels/reduction/ops.h @@ -5,11 +5,7 @@ #include #include -#ifndef MLX_METAL_JIT -#include "mlx/backend/metal/kernels/atomic.h" -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/utils.h" -#endif +static constant constexpr const uint8_t simd_size = 32; union bool4_or_uint { bool4 b; @@ -23,6 +19,7 @@ struct None { } }; +template struct And { bool simd_reduce(bool val) { return simd_all(val); @@ -60,6 +57,7 @@ struct And { } }; +template struct Or { bool simd_reduce(bool val) { return simd_any(val); diff --git a/mlx/backend/metal/kernels/reduction/kernels/reduce_all.metal b/mlx/backend/metal/kernels/reduction/reduce_all.h similarity index 61% rename from mlx/backend/metal/kernels/reduction/kernels/reduce_all.metal rename to mlx/backend/metal/kernels/reduction/reduce_all.h index 83ed3e982..e5928b615 100644 --- a/mlx/backend/metal/kernels/reduction/kernels/reduce_all.metal +++ b/mlx/backend/metal/kernels/reduction/reduce_all.h @@ -1,11 +1,5 @@ // Copyright © 2023-2024 Apple Inc. 
-#include "mlx/backend/metal/kernels/reduction/ops.h" -#include "mlx/backend/metal/kernels/reduction/reduce_inst.h" -#include "mlx/backend/metal/kernels/reduction/utils.h" - -using namespace metal; - /////////////////////////////////////////////////////////////////////////////// // All reduce helper /////////////////////////////////////////////////////////////////////////////// @@ -139,50 +133,3 @@ template out[thread_group_id] = total_val; } } - -#define instantiate_all_reduce(name, itype, otype, op) \ - template [[host_name("all_reduce_" #name)]] [[kernel]] void \ - all_reduce( \ - const device itype* in [[buffer(0)]], \ - device mlx_atomic* out [[buffer(1)]], \ - const device size_t& in_size [[buffer(2)]], \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint grid_size [[threads_per_grid]], \ - uint simd_per_group [[simdgroups_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); - -#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \ - template [[host_name("all_reduce_no_atomics_" #name)]] [[kernel]] void \ - all_reduce_no_atomics( \ - const device itype* in [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const device size_t& in_size [[buffer(2)]], \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint grid_size [[threads_per_grid]], \ - uint simd_per_group [[simdgroups_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint thread_group_id [[threadgroup_position_in_grid]]); - -/////////////////////////////////////////////////////////////////////////////// -// Instantiations -/////////////////////////////////////////////////////////////////////////////// - -#define instantiate_same_all_reduce_helper(name, tname, type, op) \ - instantiate_all_reduce(name##tname, type, type, op) - -#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \ - instantiate_all_reduce_no_atomics(name##tname, type, type, op) - -// clang-format off -instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types) -instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b) - -instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And) -instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or) - -// special case bool with larger output type -instantiate_all_reduce(sumbool_, bool, uint32_t, Sum) // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal b/mlx/backend/metal/kernels/reduction/reduce_col.h similarity index 52% rename from mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal rename to mlx/backend/metal/kernels/reduction/reduce_col.h index 7e28b11ca..568d0cebc 100644 --- a/mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal +++ b/mlx/backend/metal/kernels/reduction/reduce_col.h @@ -1,11 +1,5 @@ // Copyright © 2023-2024 Apple Inc. 
-#include "mlx/backend/metal/kernels/reduction/ops.h" -#include "mlx/backend/metal/kernels/reduction/reduce_inst.h" -#include "mlx/backend/metal/kernels/reduction/utils.h" - -using namespace metal; - /////////////////////////////////////////////////////////////////////////////// // Small column reduce kernel /////////////////////////////////////////////////////////////////////////////// @@ -52,23 +46,6 @@ template out[out_idx] = total_val; } -#define instantiate_col_reduce_small(name, itype, otype, op) \ - template [[host_name("col_reduce_small_" #name)]] [[kernel]] void \ - col_reduce_small( \ - const device itype* in [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& reduction_stride [[buffer(3)]], \ - const constant size_t& out_size [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - const constant size_t& non_col_reductions [[buffer(8)]], \ - const constant int* non_col_shapes [[buffer(9)]], \ - const constant size_t* non_col_strides [[buffer(10)]], \ - const constant int& non_col_ndim [[buffer(11)]], \ - uint tid [[thread_position_in_grid]]); - /////////////////////////////////////////////////////////////////////////////// // Column reduce helper /////////////////////////////////////////////////////////////////////////////// @@ -186,64 +163,3 @@ template } } } - -#define instantiate_col_reduce_general(name, itype, otype, op) \ - template [[host_name("col_reduce_general_" #name)]] [[kernel]] void \ - col_reduce_general( \ - const device itype* in [[buffer(0)]], \ - device mlx_atomic* out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& reduction_stride [[buffer(3)]], \ - const constant size_t& out_size [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - threadgroup otype* local_data [[threadgroup(0)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 lsize [[threads_per_threadgroup]]); - -#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \ - template \ - [[host_name("col_reduce_general_no_atomics_" #name)]] [[kernel]] void \ - col_reduce_general_no_atomics( \ - const device itype* in [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& reduction_stride [[buffer(3)]], \ - const constant size_t& out_size [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - threadgroup otype* local_data [[threadgroup(0)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 gid [[thread_position_in_grid]], \ - uint3 lsize [[threads_per_threadgroup]], \ - uint3 gsize [[threads_per_grid]]); - -/////////////////////////////////////////////////////////////////////////////// -// Instantiations -/////////////////////////////////////////////////////////////////////////////// - -// clang-format off -#define instantiate_same_col_reduce_helper(name, tname, type, op) \ - instantiate_col_reduce_small(name ##tname, type, type, op) \ - instantiate_col_reduce_general(name ##tname, type, type, op) // clang-format on - -// clang-format off -#define 
instantiate_same_col_reduce_na_helper(name, tname, type, op) \ - instantiate_col_reduce_small(name ##tname, type, type, op) \ - instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op) // clang-format on - -// clang-format off -instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types) -instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b) - -instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum) -instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And) -instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or) - -instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum) -instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And) -instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or) // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/reduction/reduce_init.h b/mlx/backend/metal/kernels/reduction/reduce_init.h new file mode 100644 index 000000000..604efa785 --- /dev/null +++ b/mlx/backend/metal/kernels/reduction/reduce_init.h @@ -0,0 +1,8 @@ +// Copyright © 2023-2024 Apple Inc. + +template +[[kernel]] void init_reduce( + device T* out [[buffer(0)]], + uint tid [[thread_position_in_grid]]) { + out[tid] = Op::init; +} diff --git a/mlx/backend/metal/kernels/reduction/reduce_inst.h b/mlx/backend/metal/kernels/reduction/reduce_inst.h deleted file mode 100644 index ff45290e7..000000000 --- a/mlx/backend/metal/kernels/reduction/reduce_inst.h +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include -#include - -#include "mlx/backend/metal/kernels/defines.h" -#include "mlx/backend/metal/kernels/reduction/ops.h" - -// clang-format off -#define instantiate_reduce_helper_floats(inst_f, name, op) \ - inst_f(name, float16, half, op) inst_f(name, float32, float, op) \ - inst_f(name, bfloat16, bfloat16_t, op) - -#define instantiate_reduce_helper_uints(inst_f, name, op) \ - inst_f(name, uint8, uint8_t, op) inst_f(name, uint16, uint16_t, op) \ - inst_f(name, uint32, uint32_t, op) - -#define instantiate_reduce_helper_ints(inst_f, name, op) \ - inst_f(name, int8, int8_t, op) inst_f(name, int16, int16_t, op) \ - inst_f(name, int32, int32_t, op) - -#define instantiate_reduce_helper_64b(inst_f, name, op) \ - inst_f(name, int64, int64_t, op) inst_f(name, uint64, uint64_t, op) - -#define instantiate_reduce_helper_types(inst_f, name, op) \ - instantiate_reduce_helper_floats(inst_f, name, op) \ - instantiate_reduce_helper_uints(inst_f, name, op) \ - instantiate_reduce_helper_ints(inst_f, name, op) - -#define instantiate_reduce_ops(inst_f, type_f) \ - type_f(inst_f, sum, Sum) type_f(inst_f, prod, Prod) \ - type_f(inst_f, min_, Min) type_f(inst_f, max_, Max) - -// Special case for bool reductions -#define instantiate_reduce_from_types_helper( \ - inst_f, name, tname, itype, otype, op) \ - inst_f(name##tname, itype, otype, op) - -#define instantiate_reduce_from_types(inst_f, name, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, bool_, bool, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, uint8, uint8_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, uint16, uint16_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, uint32, uint32_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, int8, int8_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, int16, 
int16_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, int32, int32_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, int64, int64_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, float16, half, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, \ - name, \ - float32, \ - float, \ - otype, \ - op) \ - instantiate_reduce_from_types_helper( \ - inst_f, \ - name, \ - bfloat16, \ - bfloat16_t, \ - otype, \ - op) -// clang-format on diff --git a/mlx/backend/metal/kernels/reduction/kernels/reduce_row.metal b/mlx/backend/metal/kernels/reduction/reduce_row.h similarity index 60% rename from mlx/backend/metal/kernels/reduction/kernels/reduce_row.metal rename to mlx/backend/metal/kernels/reduction/reduce_row.h index 22810ca20..cefaf63e9 100644 --- a/mlx/backend/metal/kernels/reduction/kernels/reduce_row.metal +++ b/mlx/backend/metal/kernels/reduction/reduce_row.h @@ -1,11 +1,5 @@ // Copyright © 2023-2024 Apple Inc. -#include "mlx/backend/metal/kernels/reduction/ops.h" -#include "mlx/backend/metal/kernels/reduction/reduce_inst.h" -#include "mlx/backend/metal/kernels/reduction/utils.h" - -using namespace metal; - /////////////////////////////////////////////////////////////////////////////// // Small row reductions /////////////////////////////////////////////////////////////////////////////// @@ -123,33 +117,6 @@ template } } -#define instantiate_row_reduce_small(name, itype, otype, op) \ - template [[host_name("row_reduce_general_small_" #name)]] [[kernel]] void \ - row_reduce_general_small( \ - const device itype* in [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& out_size [[buffer(3)]], \ - const constant size_t& non_row_reductions [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - uint lid [[thread_position_in_grid]]); \ - template [[host_name("row_reduce_general_med_" #name)]] [[kernel]] void \ - row_reduce_general_med( \ - const device itype* in [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& out_size [[buffer(3)]], \ - const constant size_t& non_row_reductions [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - uint tid [[threadgroup_position_in_grid]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); - /////////////////////////////////////////////////////////////////////////////// // Large row reductions /////////////////////////////////////////////////////////////////////////////// @@ -318,61 +285,3 @@ template out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val; } } - -#define instantiate_row_reduce_general(name, itype, otype, op) \ - instantiate_row_reduce_small(name, itype, otype, op) template \ - [[host_name("row_reduce_general_" #name)]] [[kernel]] void \ - row_reduce_general( \ - const device itype* in [[buffer(0)]], \ - device mlx_atomic* out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& out_size [[buffer(3)]], \ - const constant size_t& non_row_reductions [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const 
constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 lsize [[threads_per_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_per_group [[simdgroups_per_threadgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); - -#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \ - instantiate_row_reduce_small(name, itype, otype, op) template \ - [[host_name("row_reduce_general_no_atomics_" #name)]] [[kernel]] void \ - row_reduce_general_no_atomics( \ - const device itype* in [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& out_size [[buffer(3)]], \ - const constant size_t& non_row_reductions [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 lsize [[threads_per_threadgroup]], \ - uint3 gsize [[threads_per_grid]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_per_group [[simdgroups_per_threadgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); - -/////////////////////////////////////////////////////////////////////////////// -// Instantiations -/////////////////////////////////////////////////////////////////////////////// - -#define instantiate_same_row_reduce_helper(name, tname, type, op) \ - instantiate_row_reduce_general(name##tname, type, type, op) - -#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \ - instantiate_row_reduce_general_no_atomics(name##tname, type, type, op) - -// clang-format off -instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types) -instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b) - -instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And) -instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or) - -instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum) // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/reduction/utils.h b/mlx/backend/metal/kernels/reduction/utils.h deleted file mode 100644 index 6665a3eca..000000000 --- a/mlx/backend/metal/kernels/reduction/utils.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include -#include - -#include "mlx/backend/metal/kernels/defines.h" -#include "mlx/backend/metal/kernels/steel/utils.h" -#include "mlx/backend/metal/kernels/utils.h" - -#include "mlx/backend/metal/kernels/reduction/ops.h" - -static constant constexpr const uint8_t simd_size = 32; \ No newline at end of file diff --git a/mlx/backend/metal/kernels/scan.h b/mlx/backend/metal/kernels/scan.h new file mode 100644 index 000000000..19e4be86c --- /dev/null +++ b/mlx/backend/metal/kernels/scan.h @@ -0,0 +1,440 @@ +// Copyright © 2023-2024 Apple Inc. 
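// --- Illustrative sketch, not part of the MLX sources ---------------------
// The instantiate_* macros removed above build explicit instantiations as a
// cross product of (reduce op) x (dtype): one macro lists the dtypes, another
// walks the ops, and the innermost callback emits one instantiation per pair,
// while the new init_reduce kernel simply fills the output with the op's
// identity. A host-side C++ analogue of both ideas; every name below
// (reduce_init, INST, INST_FLOATS, INST_OPS) is illustrative only:

#include <cstdio>
#include <limits>

struct Sum {
  template <typename T> static constexpr T init = T(0);
};
struct Max {
  template <typename T> static constexpr T init = std::numeric_limits<T>::lowest();
};

// Analogue of init_reduce: write the op's identity to every output element.
template <typename T, typename Op>
void reduce_init(T* out, int n) {
  for (int i = 0; i < n; ++i) {
    out[i] = Op::template init<T>;
  }
}

// Innermost callback: one explicit instantiation per (dtype, op) pair.
#define INST(name, tname, type, op) template void reduce_init<type, op>(type*, int);
// Dtype family: apply the callback to each dtype for a given op.
#define INST_FLOATS(inst_f, name, op) \
  inst_f(name, float32, float, op) inst_f(name, float64, double, op)
// Op list: apply a dtype-family macro to each op.
#define INST_OPS(inst_f, type_f) type_f(inst_f, sum, Sum) type_f(inst_f, max_, Max)

INST_OPS(INST, INST_FLOATS) // expands to 2 ops x 2 dtypes = 4 instantiations

int main() {
  float buf[3];
  reduce_init<float, Max>(buf, 3);
  std::printf("%g\n", buf[0]); // lowest finite float
  return 0;
}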
+ +template +struct CumSum { + static constexpr constant U init = static_cast(0); + + template + U operator()(U a, T b) { + return a + b; + } + + U simd_scan(U x) { + return simd_prefix_inclusive_sum(x); + } + + U simd_exclusive_scan(U x) { + return simd_prefix_exclusive_sum(x); + } +}; + +template +struct CumProd { + static constexpr constant U init = static_cast(1.0f); + + template + U operator()(U a, T b) { + return a * b; + } + + U simd_scan(U x) { + return simd_prefix_inclusive_product(x); + } + + U simd_exclusive_scan(U x) { + return simd_prefix_exclusive_product(x); + } +}; + +template <> +struct CumProd { + static constexpr constant bool init = true; + + template + bool operator()(bool a, T b) { + return a & static_cast(b); + } + + bool simd_scan(bool x) { + for (int i = 1; i <= 16; i *= 2) { + bool other = simd_shuffle_up(x, i); + x &= other; + } + return x; + } + + bool simd_exclusive_scan(bool x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +template +struct CumMax { + static constexpr constant U init = Limits::min; + + template + U operator()(U a, T b) { + return (a >= b) ? a : b; + } + + U simd_scan(U x) { + for (int i = 1; i <= 16; i *= 2) { + U other = simd_shuffle_up(x, i); + x = (x >= other) ? x : other; + } + return x; + } + + U simd_exclusive_scan(U x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +template +struct CumMin { + static constexpr constant U init = Limits::max; + + template + U operator()(U a, T b) { + return (a <= b) ? a : b; + } + + U simd_scan(U x) { + for (int i = 1; i <= 16; i *= 2) { + U other = simd_shuffle_up(x, i); + x = (x <= other) ? x : other; + } + return x; + } + + U simd_exclusive_scan(U x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +template +inline void load_unsafe(U values[N_READS], const device T* input) { + if (reverse) { + for (int i = 0; i < N_READS; i++) { + values[N_READS - i - 1] = input[i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + values[i] = input[i]; + } + } +} + +template +inline void load_safe( + U values[N_READS], + const device T* input, + int start, + int total, + U init) { + if (reverse) { + for (int i = 0; i < N_READS; i++) { + values[N_READS - i - 1] = + (start + N_READS - i - 1 < total) ? input[i] : init; + } + } else { + for (int i = 0; i < N_READS; i++) { + values[i] = (start + i < total) ? 
input[i] : init; + } + } +} + +template +inline void write_unsafe(U values[N_READS], device U* out) { + if (reverse) { + for (int i = 0; i < N_READS; i++) { + out[i] = values[N_READS - i - 1]; + } + } else { + for (int i = 0; i < N_READS; i++) { + out[i] = values[i]; + } + } +} + +template +inline void write_safe(U values[N_READS], device U* out, int start, int total) { + if (reverse) { + for (int i = 0; i < N_READS; i++) { + if (start + N_READS - i - 1 < total) { + out[i] = values[N_READS - i - 1]; + } + } + } else { + for (int i = 0; i < N_READS; i++) { + if (start + i < total) { + out[i] = values[i]; + } + } + } +} + +template < + typename T, + typename U, + typename Op, + int N_READS, + bool inclusive, + bool reverse> +[[kernel]] void contiguous_scan( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& axis_size [[buffer(2)]], + uint gid [[thread_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_size [[threads_per_simdgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + + // Position the pointers + in += (gid / lsize) * axis_size; + out += (gid / lsize) * axis_size; + + // Compute the number of simd_groups + uint simd_groups = lsize / simd_size; + + // Allocate memory + U prefix = Op::init; + U values[N_READS]; + threadgroup U simdgroup_sums[32]; + + // Loop over the reduced axis in blocks of size ceildiv(axis_size, + // N_READS*lsize) + // Read block + // Compute inclusive scan of the block + // Compute inclusive scan per thread + // Compute exclusive scan of thread sums in simdgroup + // Write simdgroup sums in SM + // Compute exclusive scan of simdgroup sums + // Compute the output by scanning prefix, prev_simdgroup, prev_thread, + // value + // Write block + + for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) { + // Compute the block offset + uint offset = r * lsize * N_READS + lid * N_READS; + + // Read the values + if (reverse) { + if ((offset + N_READS) < axis_size) { + load_unsafe( + values, in + axis_size - offset - N_READS); + } else { + load_safe( + values, + in + axis_size - offset - N_READS, + offset, + axis_size, + Op::init); + } + } else { + if ((offset + N_READS) < axis_size) { + load_unsafe(values, in + offset); + } else { + load_safe( + values, in + offset, offset, axis_size, Op::init); + } + } + + // Compute an inclusive scan per thread + for (int i = 1; i < N_READS; i++) { + values[i] = op(values[i], values[i - 1]); + } + + // Compute exclusive scan of thread sums + U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]); + + // Write simdgroup_sums to SM + if (simd_lane_id == simd_size - 1) { + simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Compute exclusive scan of simdgroup_sums + if (simd_group_id == 0) { + U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]); + simdgroup_sums[simd_lane_id] = prev_simdgroup; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Compute the output + for (int i = 0; i < N_READS; i++) { + values[i] = op(values[i], prefix); + values[i] = op(values[i], simdgroup_sums[simd_group_id]); + values[i] = op(values[i], prev_thread); + } + + // Write the values + if (reverse) { + if (inclusive) { + if ((offset + N_READS) < axis_size) { + write_unsafe( + values, out + axis_size - offset - N_READS); + } else { + write_safe( + values, 
out + axis_size - offset - N_READS, offset, axis_size); + } + } else { + if (lid == 0 && offset == 0) { + out[axis_size - 1] = Op::init; + } + if ((offset + N_READS + 1) < axis_size) { + write_unsafe( + values, out + axis_size - offset - 1 - N_READS); + } else { + write_safe( + values, + out + axis_size - offset - 1 - N_READS, + offset + 1, + axis_size); + } + } + } else { + if (inclusive) { + if ((offset + N_READS) < axis_size) { + write_unsafe(values, out + offset); + } else { + write_safe( + values, out + offset, offset, axis_size); + } + } else { + if (lid == 0 && offset == 0) { + out[0] = Op::init; + } + if ((offset + N_READS + 1) < axis_size) { + write_unsafe(values, out + offset + 1); + } else { + write_safe( + values, out + offset + 1, offset + 1, axis_size); + } + } + } + + // Share the prefix + if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) { + simdgroup_sums[0] = values[N_READS - 1]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + prefix = simdgroup_sums[0]; + } +} + +template < + typename T, + typename U, + typename Op, + int N_READS, + bool inclusive, + bool reverse> +[[kernel]] void strided_scan( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& axis_size [[buffer(2)]], + const constant size_t& stride [[buffer(3)]], + uint2 gid [[threadgroup_position_in_grid]], + uint2 lid [[thread_position_in_threadgroup]], + uint2 lsize [[threads_per_threadgroup]], + uint simd_size [[threads_per_simdgroup]]) { + Op op; + + // Allocate memory + threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32]; + U values[N_READS]; + U prefix[N_READS]; + for (int i = 0; i < N_READS; i++) { + prefix[i] = Op::init; + } + + // Compute offsets + int offset = gid.y * axis_size * stride; + int global_index_x = gid.x * lsize.y * N_READS; + + for (uint j = 0; j < axis_size; j += simd_size) { + // Calculate the indices for the current thread + uint index_y = j + lid.y; + uint check_index_y = index_y; + uint index_x = global_index_x + lid.x * N_READS; + if (reverse) { + index_y = axis_size - 1 - index_y; + } + + // Read in SM + if (check_index_y < axis_size && (index_x + N_READS) < stride) { + for (int i = 0; i < N_READS; i++) { + read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = + in[offset + index_y * stride + index_x + i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + if (check_index_y < axis_size && (index_x + i) < stride) { + read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = + in[offset + index_y * stride + index_x + i]; + } else { + read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = + Op::init; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Read strided into registers + for (int i = 0; i < N_READS; i++) { + values[i] = + read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i]; + } + // Do we need the following barrier? Shouldn't all simd threads execute + // simultaneously? 
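// --- Illustrative sketch, not part of the MLX sources ---------------------
// contiguous_scan composes three prefix levels (a per-thread scan, a simdgroup
// exclusive scan of the thread totals, and a running prefix carried across
// blocks); strided_scan applies the same ops along a strided axis through
// threadgroup memory. A sequential C++ reference of the end result, useful for
// checking a kernel's output; scan_reference is an illustrative name, not an
// MLX function:

#include <cstddef>
#include <functional>
#include <vector>

template <typename T, typename Op>
std::vector<T> scan_reference(const std::vector<T>& x, Op op, T init,
                              bool inclusive, bool reverse) {
  std::vector<T> y(x.size());
  T acc = init;
  for (std::size_t k = 0; k < x.size(); ++k) {
    std::size_t i = reverse ? x.size() - 1 - k : k;  // mirror the reversed addressing
    if (inclusive) {
      acc = op(acc, x[i]);
      y[i] = acc;
    } else {
      y[i] = acc;              // exclusive: prefix before including x[i]
      acc = op(acc, x[i]);
    }
  }
  return y;
}

// e.g. scan_reference(v, std::plus<float>{}, 0.0f, /*inclusive=*/true, /*reverse=*/false)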
+ simdgroup_barrier(mem_flags::mem_threadgroup); + + // Perform the scan + for (int i = 0; i < N_READS; i++) { + values[i] = op.simd_scan(values[i]); + values[i] = op(values[i], prefix[i]); + prefix[i] = simd_shuffle(values[i], simd_size - 1); + } + + // Write to SM + for (int i = 0; i < N_READS; i++) { + read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] = + values[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write to device memory + if (!inclusive) { + if (check_index_y == 0) { + if ((index_x + N_READS) < stride) { + for (int i = 0; i < N_READS; i++) { + out[offset + index_y * stride + index_x + i] = Op::init; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((index_x + i) < stride) { + out[offset + index_y * stride + index_x + i] = Op::init; + } + } + } + } + if (reverse) { + index_y -= 1; + check_index_y += 1; + } else { + index_y += 1; + check_index_y += 1; + } + } + if (check_index_y < axis_size && (index_x + N_READS) < stride) { + for (int i = 0; i < N_READS; i++) { + out[offset + index_y * stride + index_x + i] = + read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + if (check_index_y < axis_size && (index_x + i) < stride) { + out[offset + index_y * stride + index_x + i] = + read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i]; + } + } + } + } +} diff --git a/mlx/backend/metal/kernels/scan.metal b/mlx/backend/metal/kernels/scan.metal index 42533c7ea..2f026bc36 100644 --- a/mlx/backend/metal/kernels/scan.metal +++ b/mlx/backend/metal/kernels/scan.metal @@ -1,455 +1,19 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include -#include "mlx/backend/metal/kernels/defines.h" -#include "mlx/backend/metal/kernels/utils.h" +// clang-format off using namespace metal; -template -struct CumSum { - static constexpr constant U init = static_cast(0); - - template - U operator()(U a, T b) { - return a + b; - } - - U simd_scan(U x) { - return simd_prefix_inclusive_sum(x); - } - - U simd_exclusive_scan(U x) { - return simd_prefix_exclusive_sum(x); - } -}; - -template -struct CumProd { - static constexpr constant U init = static_cast(1.0f); - - template - U operator()(U a, T b) { - return a * b; - } - - U simd_scan(U x) { - return simd_prefix_inclusive_product(x); - } - - U simd_exclusive_scan(U x) { - return simd_prefix_exclusive_product(x); - } -}; - -template <> -struct CumProd { - static constexpr constant bool init = true; - - template - bool operator()(bool a, T b) { - return a & static_cast(b); - } - - bool simd_scan(bool x) { - for (int i = 1; i <= 16; i *= 2) { - bool other = simd_shuffle_up(x, i); - x &= other; - } - return x; - } - - bool simd_exclusive_scan(bool x) { - x = simd_scan(x); - return simd_shuffle_and_fill_up(x, init, 1); - } -}; - -template -struct CumMax { - static constexpr constant U init = Limits::min; - - template - U operator()(U a, T b) { - return (a >= b) ? a : b; - } - - U simd_scan(U x) { - for (int i = 1; i <= 16; i *= 2) { - U other = simd_shuffle_up(x, i); - x = (x >= other) ? x : other; - } - return x; - } - - U simd_exclusive_scan(U x) { - x = simd_scan(x); - return simd_shuffle_and_fill_up(x, init, 1); - } -}; - -template -struct CumMin { - static constexpr constant U init = Limits::max; - - template - U operator()(U a, T b) { - return (a <= b) ? a : b; - } - - U simd_scan(U x) { - for (int i = 1; i <= 16; i *= 2) { - U other = simd_shuffle_up(x, i); - x = (x <= other) ? 
x : other; - } - return x; - } - - U simd_exclusive_scan(U x) { - x = simd_scan(x); - return simd_shuffle_and_fill_up(x, init, 1); - } -}; - -template -inline void load_unsafe(U values[N_READS], const device T* input) { - if (reverse) { - for (int i = 0; i < N_READS; i++) { - values[N_READS - i - 1] = input[i]; - } - } else { - for (int i = 0; i < N_READS; i++) { - values[i] = input[i]; - } - } -} - -template -inline void load_safe( - U values[N_READS], - const device T* input, - int start, - int total, - U init) { - if (reverse) { - for (int i = 0; i < N_READS; i++) { - values[N_READS - i - 1] = - (start + N_READS - i - 1 < total) ? input[i] : init; - } - } else { - for (int i = 0; i < N_READS; i++) { - values[i] = (start + i < total) ? input[i] : init; - } - } -} - -template -inline void write_unsafe(U values[N_READS], device U* out) { - if (reverse) { - for (int i = 0; i < N_READS; i++) { - out[i] = values[N_READS - i - 1]; - } - } else { - for (int i = 0; i < N_READS; i++) { - out[i] = values[i]; - } - } -} - -template -inline void write_safe(U values[N_READS], device U* out, int start, int total) { - if (reverse) { - for (int i = 0; i < N_READS; i++) { - if (start + N_READS - i - 1 < total) { - out[i] = values[N_READS - i - 1]; - } - } - } else { - for (int i = 0; i < N_READS; i++) { - if (start + i < total) { - out[i] = values[i]; - } - } - } -} - -template < - typename T, - typename U, - typename Op, - int N_READS, - bool inclusive, - bool reverse> -[[kernel]] void contiguous_scan( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const constant size_t& axis_size [[buffer(2)]], - uint gid [[thread_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], - uint simd_size [[threads_per_simdgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - Op op; - - // Position the pointers - in += (gid / lsize) * axis_size; - out += (gid / lsize) * axis_size; - - // Compute the number of simd_groups - uint simd_groups = lsize / simd_size; - - // Allocate memory - U prefix = Op::init; - U values[N_READS]; - threadgroup U simdgroup_sums[32]; - - // Loop over the reduced axis in blocks of size ceildiv(axis_size, - // N_READS*lsize) - // Read block - // Compute inclusive scan of the block - // Compute inclusive scan per thread - // Compute exclusive scan of thread sums in simdgroup - // Write simdgroup sums in SM - // Compute exclusive scan of simdgroup sums - // Compute the output by scanning prefix, prev_simdgroup, prev_thread, - // value - // Write block - - for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) { - // Compute the block offset - uint offset = r * lsize * N_READS + lid * N_READS; - - // Read the values - if (reverse) { - if ((offset + N_READS) < axis_size) { - load_unsafe( - values, in + axis_size - offset - N_READS); - } else { - load_safe( - values, - in + axis_size - offset - N_READS, - offset, - axis_size, - Op::init); - } - } else { - if ((offset + N_READS) < axis_size) { - load_unsafe(values, in + offset); - } else { - load_safe( - values, in + offset, offset, axis_size, Op::init); - } - } - - // Compute an inclusive scan per thread - for (int i = 1; i < N_READS; i++) { - values[i] = op(values[i], values[i - 1]); - } - - // Compute exclusive scan of thread sums - U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]); - - // Write simdgroup_sums to SM - if (simd_lane_id == simd_size - 1) { - simdgroup_sums[simd_group_id] = 
op(prev_thread, values[N_READS - 1]); - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Compute exclusive scan of simdgroup_sums - if (simd_group_id == 0) { - U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]); - simdgroup_sums[simd_lane_id] = prev_simdgroup; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Compute the output - for (int i = 0; i < N_READS; i++) { - values[i] = op(values[i], prefix); - values[i] = op(values[i], simdgroup_sums[simd_group_id]); - values[i] = op(values[i], prev_thread); - } - - // Write the values - if (reverse) { - if (inclusive) { - if ((offset + N_READS) < axis_size) { - write_unsafe( - values, out + axis_size - offset - N_READS); - } else { - write_safe( - values, out + axis_size - offset - N_READS, offset, axis_size); - } - } else { - if (lid == 0 && offset == 0) { - out[axis_size - 1] = Op::init; - } - if ((offset + N_READS + 1) < axis_size) { - write_unsafe( - values, out + axis_size - offset - 1 - N_READS); - } else { - write_safe( - values, - out + axis_size - offset - 1 - N_READS, - offset + 1, - axis_size); - } - } - } else { - if (inclusive) { - if ((offset + N_READS) < axis_size) { - write_unsafe(values, out + offset); - } else { - write_safe( - values, out + offset, offset, axis_size); - } - } else { - if (lid == 0 && offset == 0) { - out[0] = Op::init; - } - if ((offset + N_READS + 1) < axis_size) { - write_unsafe(values, out + offset + 1); - } else { - write_safe( - values, out + offset + 1, offset + 1, axis_size); - } - } - } - - // Share the prefix - if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) { - simdgroup_sums[0] = values[N_READS - 1]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - prefix = simdgroup_sums[0]; - } -} - -template < - typename T, - typename U, - typename Op, - int N_READS, - bool inclusive, - bool reverse> -[[kernel]] void strided_scan( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const constant size_t& axis_size [[buffer(2)]], - const constant size_t& stride [[buffer(3)]], - uint2 gid [[threadgroup_position_in_grid]], - uint2 lid [[thread_position_in_threadgroup]], - uint2 lsize [[threads_per_threadgroup]], - uint simd_size [[threads_per_simdgroup]]) { - Op op; - - // Allocate memory - threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32]; - U values[N_READS]; - U prefix[N_READS]; - for (int i = 0; i < N_READS; i++) { - prefix[i] = Op::init; - } - - // Compute offsets - int offset = gid.y * axis_size * stride; - int global_index_x = gid.x * lsize.y * N_READS; - - for (uint j = 0; j < axis_size; j += simd_size) { - // Calculate the indices for the current thread - uint index_y = j + lid.y; - uint check_index_y = index_y; - uint index_x = global_index_x + lid.x * N_READS; - if (reverse) { - index_y = axis_size - 1 - index_y; - } - - // Read in SM - if (check_index_y < axis_size && (index_x + N_READS) < stride) { - for (int i = 0; i < N_READS; i++) { - read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = - in[offset + index_y * stride + index_x + i]; - } - } else { - for (int i = 0; i < N_READS; i++) { - if (check_index_y < axis_size && (index_x + i) < stride) { - read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = - in[offset + index_y * stride + index_x + i]; - } else { - read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = - Op::init; - } - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Read strided into registers - for (int i = 0; i < N_READS; i++) { - 
values[i] = - read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i]; - } - // Do we need the following barrier? Shouldn't all simd threads execute - // simultaneously? - simdgroup_barrier(mem_flags::mem_threadgroup); - - // Perform the scan - for (int i = 0; i < N_READS; i++) { - values[i] = op.simd_scan(values[i]); - values[i] = op(values[i], prefix[i]); - prefix[i] = simd_shuffle(values[i], simd_size - 1); - } - - // Write to SM - for (int i = 0; i < N_READS; i++) { - read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] = - values[i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write to device memory - if (!inclusive) { - if (check_index_y == 0) { - if ((index_x + N_READS) < stride) { - for (int i = 0; i < N_READS; i++) { - out[offset + index_y * stride + index_x + i] = Op::init; - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((index_x + i) < stride) { - out[offset + index_y * stride + index_x + i] = Op::init; - } - } - } - } - if (reverse) { - index_y -= 1; - check_index_y += 1; - } else { - index_y += 1; - check_index_y += 1; - } - } - if (check_index_y < axis_size && (index_x + N_READS) < stride) { - for (int i = 0; i < N_READS; i++) { - out[offset + index_y * stride + index_x + i] = - read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i]; - } - } else { - for (int i = 0; i < N_READS; i++) { - if (check_index_y < axis_size && (index_x + i) < stride) { - out[offset + index_y * stride + index_x + i] = - read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i]; - } - } - } - } -} +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/scan.h" #define instantiate_contiguous_scan( \ name, itype, otype, op, inclusive, reverse, nreads) \ - template [[host_name("contiguous_scan_" #name)]] [[kernel]] void \ + template [[host_name("contig_scan_" #name)]] [[kernel]] void \ contiguous_scan, nreads, inclusive, reverse>( \ const device itype* in [[buffer(0)]], \ device otype* out [[buffer(1)]], \ @@ -474,7 +38,6 @@ template < uint2 lsize [[threads_per_threadgroup]], \ uint simd_size [[threads_per_simdgroup]]); -// clang-format off #define instantiate_scan_helper(name, itype, otype, op, nreads) \ instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \ instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \ @@ -483,9 +46,8 @@ template < instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \ instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \ instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \ - instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) // clang-format on + instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) -// clang-format off instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4) instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4) instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4) @@ -537,4 +99,4 @@ instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMi instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4) instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4) instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4) -//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) // 
clang-format on \ No newline at end of file +//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) // clang-format on diff --git a/mlx/backend/metal/kernels/softmax.h b/mlx/backend/metal/kernels/softmax.h new file mode 100644 index 000000000..031ec128e --- /dev/null +++ b/mlx/backend/metal/kernels/softmax.h @@ -0,0 +1,190 @@ +// Copyright © 2023-2024 Apple Inc. + +template +inline T softmax_exp(T x) { + // Softmax doesn't need high precision exponential cause x is gonna be in + // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)). + return fast::exp(x); +} + +template +[[kernel]] void softmax_single_row( + const device T* in, + device T* out, + constant int& axis_size, + uint gid [[threadgroup_position_in_grid]], + uint _lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + int lid = _lid; + + constexpr int SIMD_SIZE = 32; + + threadgroup AccT local_max[SIMD_SIZE]; + threadgroup AccT local_normalizer[SIMD_SIZE]; + + AccT ld[N_READS]; + + in += gid * axis_size + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + ld[i] = AccT(in[i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i]) + : Limits::finite_min; + } + } + if (simd_group_id == 0) { + local_max[simd_lane_id] = Limits::finite_min; + local_normalizer[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Get the max + AccT maxval = Limits::finite_min; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < ld[i]) ? ld[i] : maxval; + } + maxval = simd_max(maxval); + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + maxval = simd_max(local_max[simd_lane_id]); + if (simd_lane_id == 0) { + local_max[0] = maxval; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = local_max[0]; + + // Compute exp(x_i - maxval) and store the partial sums in local_normalizer + AccT normalizer = 0; + for (int i = 0; i < N_READS; i++) { + AccT exp_x = softmax_exp(ld[i] - maxval); + ld[i] = exp_x; + normalizer += exp_x; + } + normalizer = simd_sum(normalizer); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + normalizer = simd_sum(local_normalizer[simd_lane_id]); + if (simd_lane_id == 0) { + local_normalizer[0] = normalizer; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + normalizer = 1 / local_normalizer[0]; + + // Normalize and write to the output + out += gid * axis_size + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + out[i] = T(ld[i] * normalizer); + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + out[i] = T(ld[i] * normalizer); + } + } + } +} + +template +[[kernel]] void softmax_looped( + const device T* in, + device T* out, + constant int& axis_size, + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + in += gid * axis_size; + + constexpr int SIMD_SIZE = 32; + + threadgroup AccT local_max[SIMD_SIZE]; + threadgroup AccT local_normalizer[SIMD_SIZE]; + + 
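// --- Illustrative sketch, not part of the MLX sources ---------------------
// softmax_looped keeps a running maximum together with a normalizer that is
// always expressed relative to that maximum: whenever a larger maximum m'
// replaces m, the partial sum S = sum_i exp(x_i - m) is rescaled by
// exp(m - m') so it becomes a partial sum of exp(x_i - m'). A sequential C++
// reference of that one-pass update (softmax_stats is an illustrative name,
// and this version rescales per element rather than per block):

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

inline void softmax_stats(const std::vector<float>& x,
                          float& maxval, float& normalizer) {
  maxval = std::numeric_limits<float>::lowest();
  normalizer = 0.0f;
  for (float v : x) {
    float prevmax = maxval;
    maxval = std::max(maxval, v);
    normalizer = normalizer * std::exp(prevmax - maxval)  // rescale old partial sum
               + std::exp(v - maxval);                    // add the new term
  }
  // softmax(x_i) = exp(x_i - maxval) / normalizer
}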
// Get the max and the normalizer in one go + AccT prevmax; + AccT maxval = Limits::finite_min; + AccT normalizer = 0; + for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); + r++) { + int offset = r * lsize * N_READS + lid * N_READS; + AccT vals[N_READS]; + if (offset + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + vals[i] = AccT(in[offset + i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) + : Limits::finite_min; + } + } + prevmax = maxval; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < vals[i]) ? vals[i] : maxval; + } + normalizer *= softmax_exp(prevmax - maxval); + for (int i = 0; i < N_READS; i++) { + normalizer += softmax_exp(vals[i] - maxval); + } + } + // Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS * + // lsize) parts. We need to combine them. + // 1. We start by finding the max across simd groups + // 2. We then change the partial normalizers to account for a possible + // change in max + // 3. We sum all normalizers + prevmax = maxval; + maxval = simd_max(maxval); + normalizer *= softmax_exp(prevmax - maxval); + normalizer = simd_sum(normalizer); + + // Now the normalizer and max value is correct for each simdgroup. We write + // them shared memory and combine them. + prevmax = maxval; + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = simd_max(local_max[simd_lane_id]); + normalizer *= softmax_exp(prevmax - maxval); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + normalizer = simd_sum(local_normalizer[simd_lane_id]); + normalizer = 1 / normalizer; + + // Finally given the normalizer and max value we can directly write the + // softmax output + out += gid * axis_size; + for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); + r++) { + int offset = r * lsize * N_READS + lid * N_READS; + if (offset + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer); + } + } else { + for (int i = 0; i < N_READS; i++) { + if (offset + i < axis_size) { + out[offset + i] = + T(softmax_exp(in[offset + i] - maxval) * normalizer); + } + } + } + } +} diff --git a/mlx/backend/metal/kernels/softmax.metal b/mlx/backend/metal/kernels/softmax.metal index aaea6a279..34eea5e17 100644 --- a/mlx/backend/metal/kernels/softmax.metal +++ b/mlx/backend/metal/kernels/softmax.metal @@ -1,205 +1,18 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include +using namespace metal; + +// clang-format off #include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/utils.h" - -using namespace metal; - -template -inline T softmax_exp(T x) { - // Softmax doesn't need high precision exponential cause x is gonna be in - // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)). 
- return fast::exp(x); -} - -template -[[kernel]] void softmax_single_row( - const device T* in, - device T* out, - constant int& axis_size, - uint gid [[threadgroup_position_in_grid]], - uint _lid [[thread_position_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - int lid = _lid; - - constexpr int SIMD_SIZE = 32; - - threadgroup AccT local_max[SIMD_SIZE]; - threadgroup AccT local_normalizer[SIMD_SIZE]; - - AccT ld[N_READS]; - - in += gid * axis_size + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - ld[i] = AccT(in[i]); - } - } else { - for (int i = 0; i < N_READS; i++) { - ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i]) - : Limits::finite_min; - } - } - if (simd_group_id == 0) { - local_max[simd_lane_id] = Limits::finite_min; - local_normalizer[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Get the max - AccT maxval = Limits::finite_min; - for (int i = 0; i < N_READS; i++) { - maxval = (maxval < ld[i]) ? ld[i] : maxval; - } - maxval = simd_max(maxval); - if (simd_lane_id == 0) { - local_max[simd_group_id] = maxval; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (simd_group_id == 0) { - maxval = simd_max(local_max[simd_lane_id]); - if (simd_lane_id == 0) { - local_max[0] = maxval; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - maxval = local_max[0]; - - // Compute exp(x_i - maxval) and store the partial sums in local_normalizer - AccT normalizer = 0; - for (int i = 0; i < N_READS; i++) { - AccT exp_x = softmax_exp(ld[i] - maxval); - ld[i] = exp_x; - normalizer += exp_x; - } - normalizer = simd_sum(normalizer); - if (simd_lane_id == 0) { - local_normalizer[simd_group_id] = normalizer; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (simd_group_id == 0) { - normalizer = simd_sum(local_normalizer[simd_lane_id]); - if (simd_lane_id == 0) { - local_normalizer[0] = normalizer; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - normalizer = 1 / local_normalizer[0]; - - // Normalize and write to the output - out += gid * axis_size + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - out[i] = T(ld[i] * normalizer); - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - out[i] = T(ld[i] * normalizer); - } - } - } -} - -template -[[kernel]] void softmax_looped( - const device T* in, - device T* out, - constant int& axis_size, - uint gid [[threadgroup_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - in += gid * axis_size; - - constexpr int SIMD_SIZE = 32; - - threadgroup AccT local_max[SIMD_SIZE]; - threadgroup AccT local_normalizer[SIMD_SIZE]; - - // Get the max and the normalizer in one go - AccT prevmax; - AccT maxval = Limits::finite_min; - AccT normalizer = 0; - for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); - r++) { - int offset = r * lsize * N_READS + lid * N_READS; - AccT vals[N_READS]; - if (offset + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - vals[i] = AccT(in[offset + i]); - } - } else { - for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? 
AccT(in[offset + i]) - : Limits::finite_min; - } - } - prevmax = maxval; - for (int i = 0; i < N_READS; i++) { - maxval = (maxval < vals[i]) ? vals[i] : maxval; - } - normalizer *= softmax_exp(prevmax - maxval); - for (int i = 0; i < N_READS; i++) { - normalizer += softmax_exp(vals[i] - maxval); - } - } - // Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS * - // lsize) parts. We need to combine them. - // 1. We start by finding the max across simd groups - // 2. We then change the partial normalizers to account for a possible - // change in max - // 3. We sum all normalizers - prevmax = maxval; - maxval = simd_max(maxval); - normalizer *= softmax_exp(prevmax - maxval); - normalizer = simd_sum(normalizer); - - // Now the normalizer and max value is correct for each simdgroup. We write - // them shared memory and combine them. - prevmax = maxval; - if (simd_lane_id == 0) { - local_max[simd_group_id] = maxval; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - maxval = simd_max(local_max[simd_lane_id]); - normalizer *= softmax_exp(prevmax - maxval); - if (simd_lane_id == 0) { - local_normalizer[simd_group_id] = normalizer; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - normalizer = simd_sum(local_normalizer[simd_lane_id]); - normalizer = 1 / normalizer; - - // Finally given the normalizer and max value we can directly write the - // softmax output - out += gid * axis_size; - for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); - r++) { - int offset = r * lsize * N_READS + lid * N_READS; - if (offset + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer); - } - } else { - for (int i = 0; i < N_READS; i++) { - if (offset + i < axis_size) { - out[offset + i] = - T(softmax_exp(in[offset + i] - maxval) * normalizer); - } - } - } - } -} +#include "mlx/backend/metal/kernels/softmax.h" #define instantiate_softmax(name, itype) \ - template [[host_name("softmax_" #name)]] [[kernel]] void \ + template [[host_name("block_softmax_" #name)]] [[kernel]] void \ softmax_single_row( \ const device itype* in, \ device itype* out, \ @@ -208,7 +21,7 @@ template uint _lid [[thread_position_in_threadgroup]], \ uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ - template [[host_name("softmax_looped_" #name)]] [[kernel]] void \ + template [[host_name("looped_softmax_" #name)]] [[kernel]] void \ softmax_looped( \ const device itype* in, \ device itype* out, \ @@ -220,7 +33,7 @@ template uint simd_group_id [[simdgroup_index_in_threadgroup]]); #define instantiate_softmax_precise(name, itype) \ - template [[host_name("softmax_precise_" #name)]] [[kernel]] void \ + template [[host_name("block_softmax_precise_" #name)]] [[kernel]] void \ softmax_single_row( \ const device itype* in, \ device itype* out, \ @@ -229,7 +42,7 @@ template uint _lid [[thread_position_in_threadgroup]], \ uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ - template [[host_name("softmax_looped_precise_" #name)]] [[kernel]] void \ + template [[host_name("looped_softmax_precise_" #name)]] [[kernel]] void \ softmax_looped( \ const device itype* in, \ device itype* out, \ @@ -240,7 +53,6 @@ template uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]]); -// clang-format off instantiate_softmax(float32, float) instantiate_softmax(float16, half) 
instantiate_softmax(bfloat16, bfloat16_t) diff --git a/mlx/backend/metal/kernels/sort.h b/mlx/backend/metal/kernels/sort.h new file mode 100644 index 000000000..5b1cf2675 --- /dev/null +++ b/mlx/backend/metal/kernels/sort.h @@ -0,0 +1,674 @@ +// Copyright © 2023-2024 Apple Inc. + +#define MLX_MTL_CONST static constant constexpr const +#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)") + +using namespace metal; + +// Based on GPU merge sort algorithm at +// https://github.com/NVIDIA/cccl/tree/main/cub/cub + +/////////////////////////////////////////////////////////////////////////////// +// Thread-level sort +/////////////////////////////////////////////////////////////////////////////// + +template +METAL_FUNC void thread_swap(thread T& a, thread T& b) { + T w = a; + a = b; + b = w; +} + +template +struct LessThan { + static constexpr constant T init = Limits::max; + + METAL_FUNC bool operator()(T a, T b) { + return a < b; + } +}; + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short N_PER_THREAD, + typename CompareOp> +struct ThreadSort { + static METAL_FUNC void sort( + thread val_t (&vals)[N_PER_THREAD], + thread idx_t (&idxs)[N_PER_THREAD]) { + CompareOp op; + + MLX_MTL_LOOP_UNROLL + for (short i = 0; i < N_PER_THREAD; ++i) { + MLX_MTL_LOOP_UNROLL + for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) { + if (op(vals[j + 1], vals[j])) { + thread_swap(vals[j + 1], vals[j]); + thread_swap(idxs[j + 1], idxs[j]); + } + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Threadgroup-level sort +/////////////////////////////////////////////////////////////////////////////// + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp> +struct BlockMergeSort { + using thread_sort_t = + ThreadSort; + static METAL_FUNC int merge_partition( + const threadgroup val_t* As, + const threadgroup val_t* Bs, + short A_sz, + short B_sz, + short sort_md) { + CompareOp op; + + short A_st = max(0, sort_md - B_sz); + short A_ed = min(sort_md, A_sz); + + while (A_st < A_ed) { + short md = A_st + (A_ed - A_st) / 2; + auto a = As[md]; + auto b = Bs[sort_md - 1 - md]; + + if (op(b, a)) { + A_ed = md; + } else { + A_st = md + 1; + } + } + + return A_ed; + } + + static METAL_FUNC void merge_step( + const threadgroup val_t* As, + const threadgroup val_t* Bs, + const threadgroup idx_t* As_idx, + const threadgroup idx_t* Bs_idx, + short A_sz, + short B_sz, + thread val_t (&vals)[N_PER_THREAD], + thread idx_t (&idxs)[N_PER_THREAD]) { + CompareOp op; + short a_idx = 0; + short b_idx = 0; + + for (int i = 0; i < N_PER_THREAD; ++i) { + auto a = As[a_idx]; + auto b = Bs[b_idx]; + bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a)); + + vals[i] = pred ? b : a; + idxs[i] = pred ? 
Bs_idx[b_idx] : As_idx[a_idx]; + + b_idx += short(pred); + a_idx += short(!pred); + } + } + + static METAL_FUNC void sort( + threadgroup val_t* tgp_vals [[threadgroup(0)]], + threadgroup idx_t* tgp_idxs [[threadgroup(1)]], + int size_sorted_axis, + uint3 lid [[thread_position_in_threadgroup]]) { + // Get thread location + int idx = lid.x * N_PER_THREAD; + + // Load from shared memory + thread val_t thread_vals[N_PER_THREAD]; + thread idx_t thread_idxs[N_PER_THREAD]; + for (int i = 0; i < N_PER_THREAD; ++i) { + thread_vals[i] = tgp_vals[idx + i]; + if (ARG_SORT) { + thread_idxs[i] = tgp_idxs[idx + i]; + } + } + + // Per thread sort + if (idx < size_sorted_axis) { + thread_sort_t::sort(thread_vals, thread_idxs); + } + + // Do merges using threadgroup memory + for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; + merge_threads *= 2) { + // Update threadgroup memory + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; ++i) { + tgp_vals[idx + i] = thread_vals[i]; + if (ARG_SORT) { + tgp_idxs[idx + i] = thread_idxs[i]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Find location in merge step + int merge_group = lid.x / merge_threads; + int merge_lane = lid.x % merge_threads; + + int sort_sz = N_PER_THREAD * merge_threads; + int sort_st = N_PER_THREAD * merge_threads * merge_group; + + // As = tgp_vals[A_st:A_ed] is sorted + // Bs = tgp_vals[B_st:B_ed] is sorted + int A_st = sort_st; + int A_ed = sort_st + sort_sz / 2; + int B_st = sort_st + sort_sz / 2; + int B_ed = sort_st + sort_sz; + + const threadgroup val_t* As = tgp_vals + A_st; + const threadgroup val_t* Bs = tgp_vals + B_st; + int A_sz = A_ed - A_st; + int B_sz = B_ed - B_st; + + // Find a partition of merge elements + // Ci = merge(As[partition:], Bs[sort_md - partition:]) + // of size N_PER_THREAD for each merge lane i + // C = [Ci] is sorted + int sort_md = N_PER_THREAD * merge_lane; + int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md); + + As += partition; + Bs += sort_md - partition; + + A_sz -= partition; + B_sz -= sort_md - partition; + + const threadgroup idx_t* As_idx = + ARG_SORT ? tgp_idxs + A_st + partition : nullptr; + const threadgroup idx_t* Bs_idx = + ARG_SORT ? 
tgp_idxs + B_st + sort_md - partition : nullptr; + + // Merge starting at the partition and store results in thread registers + merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs); + } + + // Write out to shared memory + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; ++i) { + tgp_vals[idx + i] = thread_vals[i]; + if (ARG_SORT) { + tgp_idxs[idx + i] = thread_idxs[i]; + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Kernel sort +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan> +struct KernelMergeSort { + using val_t = T; + using idx_t = uint; + using block_merge_sort_t = BlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; + + static METAL_FUNC void block_sort( + const device T* inp, + device U* out, + const constant int& size_sorted_axis, + const constant int& stride_sorted_axis, + const constant int& stride_segment_axis, + threadgroup val_t* tgp_vals, + threadgroup idx_t* tgp_idxs, + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // tid.y tells us the segment index + inp += tid.y * stride_segment_axis; + out += tid.y * stride_segment_axis; + + // Copy into threadgroup memory + for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { + tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis] + : val_t(CompareOp::init); + if (ARG_SORT) { + tgp_idxs[i] = i; + } + } + + // Sort elements within the block + threadgroup_barrier(mem_flags::mem_threadgroup); + + block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write output + for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) { + if (ARG_SORT) { + out[i * stride_sorted_axis] = tgp_idxs[i]; + } else { + out[i * stride_sorted_axis] = tgp_vals[i]; + } + } + } +}; + +template < + typename T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort( + const device T* inp [[buffer(0)]], + device U* out [[buffer(1)]], + const constant int& size_sorted_axis [[buffer(2)]], + const constant int& stride_sorted_axis [[buffer(3)]], + const constant int& stride_segment_axis [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = + KernelMergeSort; + using val_t = typename sort_kernel::val_t; + using idx_t = typename sort_kernel::idx_t; + + if (ARG_SORT) { + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + stride_sorted_axis, + stride_segment_axis, + tgp_vals, + tgp_idxs, + tid, + lid); + } else { + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + stride_sorted_axis, + stride_segment_axis, + tgp_vals, + nullptr, + tid, + lid); + } +} + +constant constexpr const int zero_helper = 0; + +template < + typename T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc( + const 
device T* inp [[buffer(0)]], + device U* out [[buffer(1)]], + const constant int& size_sorted_axis [[buffer(2)]], + const constant int& stride_sorted_axis [[buffer(3)]], + const constant int& nc_dim [[buffer(4)]], + const device int* nc_shape [[buffer(5)]], + const device size_t* nc_strides [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = + KernelMergeSort; + using val_t = typename sort_kernel::val_t; + using idx_t = typename sort_kernel::idx_t; + + auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim); + inp += block_idx; + out += block_idx; + + if (ARG_SORT) { + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + stride_sorted_axis, + zero_helper, + tgp_vals, + tgp_idxs, + tid, + lid); + } else { + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + stride_sorted_axis, + zero_helper, + tgp_vals, + nullptr, + tid, + lid); + } +} + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan> +struct KernelMultiBlockMergeSort { + using block_merge_sort_t = BlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; + + static METAL_FUNC void block_sort( + const device val_t* inp, + device val_t* out_vals, + device idx_t* out_idxs, + const constant int& size_sorted_axis, + const constant int& stride_sorted_axis, + threadgroup val_t* tgp_vals, + threadgroup idx_t* tgp_idxs, + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // tid.y tells us the segment index + int base_idx = tid.x * N_PER_BLOCK; + + // Copy into threadgroup memory + for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { + int idx = base_idx + i; + tgp_vals[i] = idx < size_sorted_axis ? 
inp[idx * stride_sorted_axis] + : val_t(CompareOp::init); + tgp_idxs[i] = idx; + } + + // Sort elements within the block + threadgroup_barrier(mem_flags::mem_threadgroup); + + block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write output + for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { + int idx = base_idx + i; + if (idx < size_sorted_axis) { + out_vals[idx] = tgp_vals[i]; + out_idxs[idx] = tgp_idxs[i]; + } + } + } + + static METAL_FUNC int merge_partition( + const device val_t* As, + const device val_t* Bs, + int A_sz, + int B_sz, + int sort_md) { + CompareOp op; + + int A_st = max(0, sort_md - B_sz); + int A_ed = min(sort_md, A_sz); + + while (A_st < A_ed) { + int md = A_st + (A_ed - A_st) / 2; + auto a = As[md]; + auto b = Bs[sort_md - 1 - md]; + + if (op(b, a)) { + A_ed = md; + } else { + A_st = md + 1; + } + } + + return A_ed; + } +}; + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort( + const device val_t* inp [[buffer(0)]], + device val_t* out_vals [[buffer(1)]], + device idx_t* out_idxs [[buffer(2)]], + const constant int& size_sorted_axis [[buffer(3)]], + const constant int& stride_sorted_axis [[buffer(4)]], + const constant int& nc_dim [[buffer(5)]], + const device int* nc_shape [[buffer(6)]], + const device size_t* nc_strides [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = KernelMultiBlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD>; + + auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim); + inp += block_idx; + out_vals += tid.y * size_sorted_axis; + out_idxs += tid.y * size_sorted_axis; + + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + + sort_kernel::block_sort( + inp, + out_vals, + out_idxs, + size_sorted_axis, + stride_sorted_axis, + tgp_vals, + tgp_idxs, + tid, + lid); +} + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void +mb_block_partition( + device idx_t* block_partitions [[buffer(0)]], + const device val_t* dev_vals [[buffer(1)]], + const device idx_t* dev_idxs [[buffer(2)]], + const constant int& size_sorted_axis [[buffer(3)]], + const constant int& merge_tiles [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 tgp_dims [[threads_per_threadgroup]]) { + using sort_kernel = KernelMultiBlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD>; + + block_partitions += tid.y * tgp_dims.x; + dev_vals += tid.y * size_sorted_axis; + dev_idxs += tid.y * size_sorted_axis; + + // Find location in merge step + int merge_group = lid.x / merge_tiles; + int merge_lane = lid.x % merge_tiles; + + int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; + int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; + + int A_st = min(size_sorted_axis, sort_st); + int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); + int B_st = A_ed; + int B_ed = min(size_sorted_axis, B_st + sort_sz / 2); + + int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane); + int partition = sort_kernel::merge_partition( + dev_vals + A_st, dev_vals + B_st, 
A_ed - A_st, B_ed - B_st, partition_at); + + block_partitions[lid.x] = A_st + partition; +} + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void +mb_block_merge( + const device idx_t* block_partitions [[buffer(0)]], + const device val_t* dev_vals_in [[buffer(1)]], + const device idx_t* dev_idxs_in [[buffer(2)]], + device val_t* dev_vals_out [[buffer(3)]], + device idx_t* dev_idxs_out [[buffer(4)]], + const constant int& size_sorted_axis [[buffer(5)]], + const constant int& merge_tiles [[buffer(6)]], + const constant int& num_tiles [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = KernelMultiBlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + using block_sort_t = typename sort_kernel::block_merge_sort_t; + + block_partitions += tid.y * (num_tiles + 1); + dev_vals_in += tid.y * size_sorted_axis; + dev_idxs_in += tid.y * size_sorted_axis; + dev_vals_out += tid.y * size_sorted_axis; + dev_idxs_out += tid.y * size_sorted_axis; + + int block_idx = tid.x; + int merge_group = block_idx / merge_tiles; + int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; + int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; + int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st; + + int A_st = block_partitions[block_idx + 0]; + int A_ed = block_partitions[block_idx + 1]; + int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st); + int B_ed = min( + size_sorted_axis, + 2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed); + + if ((block_idx % merge_tiles) == merge_tiles - 1) { + A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); + B_ed = min(size_sorted_axis, sort_st + sort_sz); + } + + int A_sz = A_ed - A_st; + int B_sz = B_ed - B_st; + + // Load from global memory + thread val_t thread_vals[N_PER_THREAD]; + thread idx_t thread_idxs[N_PER_THREAD]; + for (int i = 0; i < N_PER_THREAD; i++) { + int idx = BLOCK_THREADS * i + lid.x; + if (idx < (A_sz + B_sz)) { + thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] + : dev_vals_in[B_st + idx - A_sz]; + thread_idxs[i] = (idx < A_sz) ? 
dev_idxs_in[A_st + idx] + : dev_idxs_in[B_st + idx - A_sz]; + } else { + thread_vals[i] = CompareOp::init; + thread_idxs[i] = 0; + } + } + + // Write to shared memory + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; i++) { + int idx = BLOCK_THREADS * i + lid.x; + tgp_vals[idx] = thread_vals[i]; + tgp_idxs[idx] = thread_idxs[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Merge + int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x)); + + int A_st_local = block_sort_t::merge_partition( + tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local); + int A_ed_local = A_sz; + + int B_st_local = sort_md_local - A_st_local; + int B_ed_local = B_sz; + + int A_sz_local = A_ed_local - A_st_local; + int B_sz_local = B_ed_local - B_st_local; + + // Do merge + block_sort_t::merge_step( + tgp_vals + A_st_local, + tgp_vals + A_ed_local + B_st_local, + tgp_idxs + A_st_local, + tgp_idxs + A_ed_local + B_st_local, + A_sz_local, + B_sz_local, + thread_vals, + thread_idxs); + + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; ++i) { + int idx = lid.x * N_PER_THREAD; + tgp_vals[idx + i] = thread_vals[i]; + tgp_idxs[idx + i] = thread_idxs[i]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + // Write output + int base_idx = tid.x * sort_kernel::N_PER_BLOCK; + for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) { + int idx = base_idx + i; + if (idx < size_sorted_axis) { + dev_vals_out[idx] = tgp_vals[i]; + dev_idxs_out[idx] = tgp_idxs[i]; + } + } +} diff --git a/mlx/backend/metal/kernels/sort.metal b/mlx/backend/metal/kernels/sort.metal index b2d8ae815..ad13cc835 100644 --- a/mlx/backend/metal/kernels/sort.metal +++ b/mlx/backend/metal/kernels/sort.metal @@ -1,392 +1,16 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. 
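// --- Illustrative sketch, not part of the MLX sources ---------------------
// Both merge_partition routines in sort.h implement a "merge path" search: for
// a target output position sort_md of merge(A, B), binary-search how many
// elements come from A so that A[0:i) and B[0:sort_md-i) together form the
// first sort_md merged outputs. A host-side C++ reference of that search
// (merge_partition_ref is an illustrative name):

#include <algorithm>
#include <vector>

template <typename T>
int merge_partition_ref(const std::vector<T>& A, const std::vector<T>& B, int sort_md) {
  int lo = std::max(0, sort_md - static_cast<int>(B.size()));
  int hi = std::min(sort_md, static_cast<int>(A.size()));
  while (lo < hi) {
    int mid = lo + (hi - lo) / 2;
    if (B[sort_md - 1 - mid] < A[mid]) {
      hi = mid;      // taking A[mid] would skip a smaller element of B
    } else {
      lo = mid + 1;  // A[mid] belongs to the first sort_md outputs
    }
  }
  return lo;         // number of elements of A among the first sort_md outputs
}

// e.g. with A = {1, 4, 6}, B = {2, 3, 5} and sort_md = 4, the result is 2:
// the first four merged values {1, 2, 3, 4} take two elements from each array.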
#include +// clang-format off #include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/utils.h" - -#define MLX_MTL_CONST static constant constexpr const -#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)") - -using namespace metal; - -// Based on GPU merge sort algorithm at -// https://github.com/NVIDIA/cccl/tree/main/cub/cub - -/////////////////////////////////////////////////////////////////////////////// -// Thread-level sort -/////////////////////////////////////////////////////////////////////////////// - -template -METAL_FUNC void thread_swap(thread T& a, thread T& b) { - T w = a; - a = b; - b = w; -} - -template -struct LessThan { - static constexpr constant T init = Limits::max; - - METAL_FUNC bool operator()(T a, T b) { - return a < b; - } -}; - -template < - typename val_t, - typename idx_t, - bool ARG_SORT, - short N_PER_THREAD, - typename CompareOp> -struct ThreadSort { - static METAL_FUNC void sort( - thread val_t (&vals)[N_PER_THREAD], - thread idx_t (&idxs)[N_PER_THREAD]) { - CompareOp op; - - MLX_MTL_LOOP_UNROLL - for (short i = 0; i < N_PER_THREAD; ++i) { - MLX_MTL_LOOP_UNROLL - for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) { - if (op(vals[j + 1], vals[j])) { - thread_swap(vals[j + 1], vals[j]); - thread_swap(idxs[j + 1], idxs[j]); - } - } - } - } -}; - -/////////////////////////////////////////////////////////////////////////////// -// Threadgroup-level sort -/////////////////////////////////////////////////////////////////////////////// - -template < - typename val_t, - typename idx_t, - bool ARG_SORT, - short BLOCK_THREADS, - short N_PER_THREAD, - typename CompareOp> -struct BlockMergeSort { - using thread_sort_t = - ThreadSort; - static METAL_FUNC int merge_partition( - const threadgroup val_t* As, - const threadgroup val_t* Bs, - short A_sz, - short B_sz, - short sort_md) { - CompareOp op; - - short A_st = max(0, sort_md - B_sz); - short A_ed = min(sort_md, A_sz); - - while (A_st < A_ed) { - short md = A_st + (A_ed - A_st) / 2; - auto a = As[md]; - auto b = Bs[sort_md - 1 - md]; - - if (op(b, a)) { - A_ed = md; - } else { - A_st = md + 1; - } - } - - return A_ed; - } - - static METAL_FUNC void merge_step( - const threadgroup val_t* As, - const threadgroup val_t* Bs, - const threadgroup idx_t* As_idx, - const threadgroup idx_t* Bs_idx, - short A_sz, - short B_sz, - thread val_t (&vals)[N_PER_THREAD], - thread idx_t (&idxs)[N_PER_THREAD]) { - CompareOp op; - short a_idx = 0; - short b_idx = 0; - - for (int i = 0; i < N_PER_THREAD; ++i) { - auto a = As[a_idx]; - auto b = Bs[b_idx]; - bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a)); - - vals[i] = pred ? b : a; - idxs[i] = pred ? 
Bs_idx[b_idx] : As_idx[a_idx]; - - b_idx += short(pred); - a_idx += short(!pred); - } - } - - static METAL_FUNC void sort( - threadgroup val_t* tgp_vals [[threadgroup(0)]], - threadgroup idx_t* tgp_idxs [[threadgroup(1)]], - int size_sorted_axis, - uint3 lid [[thread_position_in_threadgroup]]) { - // Get thread location - int idx = lid.x * N_PER_THREAD; - - // Load from shared memory - thread val_t thread_vals[N_PER_THREAD]; - thread idx_t thread_idxs[N_PER_THREAD]; - for (int i = 0; i < N_PER_THREAD; ++i) { - thread_vals[i] = tgp_vals[idx + i]; - if (ARG_SORT) { - thread_idxs[i] = tgp_idxs[idx + i]; - } - } - - // Per thread sort - if (idx < size_sorted_axis) { - thread_sort_t::sort(thread_vals, thread_idxs); - } - - // Do merges using threadgroup memory - for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; - merge_threads *= 2) { - // Update threadgroup memory - threadgroup_barrier(mem_flags::mem_threadgroup); - for (int i = 0; i < N_PER_THREAD; ++i) { - tgp_vals[idx + i] = thread_vals[i]; - if (ARG_SORT) { - tgp_idxs[idx + i] = thread_idxs[i]; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Find location in merge step - int merge_group = lid.x / merge_threads; - int merge_lane = lid.x % merge_threads; - - int sort_sz = N_PER_THREAD * merge_threads; - int sort_st = N_PER_THREAD * merge_threads * merge_group; - - // As = tgp_vals[A_st:A_ed] is sorted - // Bs = tgp_vals[B_st:B_ed] is sorted - int A_st = sort_st; - int A_ed = sort_st + sort_sz / 2; - int B_st = sort_st + sort_sz / 2; - int B_ed = sort_st + sort_sz; - - const threadgroup val_t* As = tgp_vals + A_st; - const threadgroup val_t* Bs = tgp_vals + B_st; - int A_sz = A_ed - A_st; - int B_sz = B_ed - B_st; - - // Find a partition of merge elements - // Ci = merge(As[partition:], Bs[sort_md - partition:]) - // of size N_PER_THREAD for each merge lane i - // C = [Ci] is sorted - int sort_md = N_PER_THREAD * merge_lane; - int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md); - - As += partition; - Bs += sort_md - partition; - - A_sz -= partition; - B_sz -= sort_md - partition; - - const threadgroup idx_t* As_idx = - ARG_SORT ? tgp_idxs + A_st + partition : nullptr; - const threadgroup idx_t* Bs_idx = - ARG_SORT ? 
tgp_idxs + B_st + sort_md - partition : nullptr; - - // Merge starting at the partition and store results in thread registers - merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs); - } - - // Write out to shared memory - threadgroup_barrier(mem_flags::mem_threadgroup); - for (int i = 0; i < N_PER_THREAD; ++i) { - tgp_vals[idx + i] = thread_vals[i]; - if (ARG_SORT) { - tgp_idxs[idx + i] = thread_idxs[i]; - } - } - } -}; - -/////////////////////////////////////////////////////////////////////////////// -// Kernel sort -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - typename U, - bool ARG_SORT, - short BLOCK_THREADS, - short N_PER_THREAD, - typename CompareOp = LessThan> -struct KernelMergeSort { - using val_t = T; - using idx_t = uint; - using block_merge_sort_t = BlockMergeSort< - val_t, - idx_t, - ARG_SORT, - BLOCK_THREADS, - N_PER_THREAD, - CompareOp>; - - MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; - - static METAL_FUNC void block_sort( - const device T* inp, - device U* out, - const constant int& size_sorted_axis, - const constant int& stride_sorted_axis, - const constant int& stride_segment_axis, - threadgroup val_t* tgp_vals, - threadgroup idx_t* tgp_idxs, - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - // tid.y tells us the segment index - inp += tid.y * stride_segment_axis; - out += tid.y * stride_segment_axis; - - // Copy into threadgroup memory - for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { - tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis] - : val_t(CompareOp::init); - if (ARG_SORT) { - tgp_idxs[i] = i; - } - } - - // Sort elements within the block - threadgroup_barrier(mem_flags::mem_threadgroup); - - block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write output - for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) { - if (ARG_SORT) { - out[i * stride_sorted_axis] = tgp_idxs[i]; - } else { - out[i * stride_sorted_axis] = tgp_vals[i]; - } - } - } -}; - -template < - typename T, - typename U, - bool ARG_SORT, - short BLOCK_THREADS, - short N_PER_THREAD> -[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort( - const device T* inp [[buffer(0)]], - device U* out [[buffer(1)]], - const constant int& size_sorted_axis [[buffer(2)]], - const constant int& stride_sorted_axis [[buffer(3)]], - const constant int& stride_segment_axis [[buffer(4)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - using sort_kernel = - KernelMergeSort; - using val_t = typename sort_kernel::val_t; - using idx_t = typename sort_kernel::idx_t; - - if (ARG_SORT) { - threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; - threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; - sort_kernel::block_sort( - inp, - out, - size_sorted_axis, - stride_sorted_axis, - stride_segment_axis, - tgp_vals, - tgp_idxs, - tid, - lid); - } else { - threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; - sort_kernel::block_sort( - inp, - out, - size_sorted_axis, - stride_sorted_axis, - stride_segment_axis, - tgp_vals, - nullptr, - tid, - lid); - } -} - -constant constexpr const int zero_helper = 0; - -template < - typename T, - typename U, - bool ARG_SORT, - short BLOCK_THREADS, - short N_PER_THREAD> -[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc( - const 
device T* inp [[buffer(0)]], - device U* out [[buffer(1)]], - const constant int& size_sorted_axis [[buffer(2)]], - const constant int& stride_sorted_axis [[buffer(3)]], - const constant int& nc_dim [[buffer(4)]], - const device int* nc_shape [[buffer(5)]], - const device size_t* nc_strides [[buffer(6)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - using sort_kernel = - KernelMergeSort; - using val_t = typename sort_kernel::val_t; - using idx_t = typename sort_kernel::idx_t; - - auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim); - inp += block_idx; - out += block_idx; - - if (ARG_SORT) { - threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; - threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; - sort_kernel::block_sort( - inp, - out, - size_sorted_axis, - stride_sorted_axis, - zero_helper, - tgp_vals, - tgp_idxs, - tid, - lid); - } else { - threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; - sort_kernel::block_sort( - inp, - out, - size_sorted_axis, - stride_sorted_axis, - zero_helper, - tgp_vals, - nullptr, - tid, - lid); - } -} - -/////////////////////////////////////////////////////////////////////////////// -// Instantiations -/////////////////////////////////////////////////////////////////////////////// +#include "mlx/backend/metal/kernels/sort.h" #define instantiate_block_sort( \ name, itname, itype, otname, otype, arg_sort, bn, tn) \ - template [[host_name(#name "_" #itname "_" #otname "_bn" #bn \ + template [[host_name("c" #name "_" #itname "_" #otname "_bn" #bn \ "_tn" #tn)]] [[kernel]] void \ block_sort( \ const device itype* inp [[buffer(0)]], \ @@ -396,8 +20,8 @@ template < const constant int& stride_segment_axis [[buffer(4)]], \ uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]]); \ - template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \ - "_nc")]] [[kernel]] void \ + template [[host_name("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \ + )]] [[kernel]] void \ block_sort_nc( \ const device itype* inp [[buffer(0)]], \ device otype* out [[buffer(1)]], \ @@ -411,22 +35,20 @@ template < #define instantiate_arg_block_sort_base(itname, itype, bn, tn) \ instantiate_block_sort( \ - arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn) + arg_block_sort, itname, itype, uint32, uint32_t, true, bn, tn) #define instantiate_block_sort_base(itname, itype, bn, tn) \ instantiate_block_sort( \ - block_merge_sort, itname, itype, itname, itype, false, bn, tn) + _block_sort, itname, itype, itname, itype, false, bn, tn) -// clang-format off #define instantiate_block_sort_tn(itname, itype, bn) \ instantiate_block_sort_base(itname, itype, bn, 8) \ - instantiate_arg_block_sort_base(itname, itype, bn, 8) // clang-format on + instantiate_arg_block_sort_base(itname, itype, bn, 8) -// clang-format off #define instantiate_block_sort_bn(itname, itype) \ instantiate_block_sort_tn(itname, itype, 128) \ instantiate_block_sort_tn(itname, itype, 256) \ - instantiate_block_sort_tn(itname, itype, 512) + instantiate_block_sort_tn(itname, itype, 512) instantiate_block_sort_bn(uint8, uint8_t) instantiate_block_sort_bn(uint16, uint16_t) @@ -436,321 +58,18 @@ instantiate_block_sort_bn(int16, int16_t) instantiate_block_sort_bn(int32, int32_t) instantiate_block_sort_bn(float16, half) instantiate_block_sort_bn(float32, float) -instantiate_block_sort_bn(bfloat16, bfloat16_t) // clang-format on -// clang-format off +instantiate_block_sort_bn(bfloat16, bfloat16_t) + 
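// --- Illustrative note (not part of the patch) ----------------------------
// With the renamed instantiations above, the host-side name of a block sort
// kernel is now assembled as
//   ("c" | "nc") + ["arg"] + "_block_sort_" + in_type + "_" + out_type
//   + "_bn" + BN + "_tn" + TN
// mirroring the kname construction in mlx/backend/metal/sort.cpp later in
// this patch. A standalone sketch of that naming scheme (the helper name is
// illustrative only):

#include <sstream>
#include <string>

std::string block_sort_kernel_name(bool contiguous, bool argsort,
                                   const std::string& in_t,
                                   const std::string& out_t,
                                   int bn, int tn) {
  std::ostringstream kname;
  kname << (contiguous ? "c" : "nc");
  if (argsort) {
    kname << "arg";
  }
  kname << "_block_sort_" << in_t << "_" << out_t << "_bn" << bn << "_tn" << tn;
  return kname.str();
}

// block_sort_kernel_name(true, true, "float32", "uint32", 512, 8)
//   -> "carg_block_sort_float32_uint32_bn512_tn8"
// block_sort_kernel_name(false, false, "float16", "float16", 256, 8)
//   -> "nc_block_sort_float16_float16_bn256_tn8"
// ---------------------------------------------------------------------------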
#define instantiate_block_sort_long(itname, itype) \ instantiate_block_sort_tn(itname, itype, 128) \ instantiate_block_sort_tn(itname, itype, 256) instantiate_block_sort_long(uint64, uint64_t) -instantiate_block_sort_long(int64, int64_t) // clang-format on - - /////////////////////////////////////////////////////////////////////////////// - // Multi block merge sort - /////////////////////////////////////////////////////////////////////////////// - - template < - typename val_t, - typename idx_t, - bool ARG_SORT, - short BLOCK_THREADS, - short N_PER_THREAD, - typename CompareOp = LessThan> - struct KernelMultiBlockMergeSort { - using block_merge_sort_t = BlockMergeSort< - val_t, - idx_t, - ARG_SORT, - BLOCK_THREADS, - N_PER_THREAD, - CompareOp>; - - MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; - - static METAL_FUNC void block_sort( - const device val_t* inp, - device val_t* out_vals, - device idx_t* out_idxs, - const constant int& size_sorted_axis, - const constant int& stride_sorted_axis, - threadgroup val_t* tgp_vals, - threadgroup idx_t* tgp_idxs, - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - // tid.y tells us the segment index - int base_idx = tid.x * N_PER_BLOCK; - - // Copy into threadgroup memory - for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { - int idx = base_idx + i; - tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] - : val_t(CompareOp::init); - tgp_idxs[i] = idx; - } - - // Sort elements within the block - threadgroup_barrier(mem_flags::mem_threadgroup); - - block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write output - for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { - int idx = base_idx + i; - if (idx < size_sorted_axis) { - out_vals[idx] = tgp_vals[i]; - out_idxs[idx] = tgp_idxs[i]; - } - } - } - - static METAL_FUNC int merge_partition( - const device val_t* As, - const device val_t* Bs, - int A_sz, - int B_sz, - int sort_md) { - CompareOp op; - - int A_st = max(0, sort_md - B_sz); - int A_ed = min(sort_md, A_sz); - - while (A_st < A_ed) { - int md = A_st + (A_ed - A_st) / 2; - auto a = As[md]; - auto b = Bs[sort_md - 1 - md]; - - if (op(b, a)) { - A_ed = md; - } else { - A_st = md + 1; - } - } - - return A_ed; - } -}; - -template < - typename val_t, - typename idx_t, - bool ARG_SORT, - short BLOCK_THREADS, - short N_PER_THREAD> -[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort( - const device val_t* inp [[buffer(0)]], - device val_t* out_vals [[buffer(1)]], - device idx_t* out_idxs [[buffer(2)]], - const constant int& size_sorted_axis [[buffer(3)]], - const constant int& stride_sorted_axis [[buffer(4)]], - const constant int& nc_dim [[buffer(5)]], - const device int* nc_shape [[buffer(6)]], - const device size_t* nc_strides [[buffer(7)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - using sort_kernel = KernelMultiBlockMergeSort< - val_t, - idx_t, - ARG_SORT, - BLOCK_THREADS, - N_PER_THREAD>; - - auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim); - inp += block_idx; - out_vals += tid.y * size_sorted_axis; - out_idxs += tid.y * size_sorted_axis; - - threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; - threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; - - sort_kernel::block_sort( - inp, - out_vals, - out_idxs, - size_sorted_axis, - stride_sorted_axis, - tgp_vals, - tgp_idxs, - 
tid, - lid); -} - -template < - typename val_t, - typename idx_t, - bool ARG_SORT, - short BLOCK_THREADS, - short N_PER_THREAD> -[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void -mb_block_partition( - device idx_t* block_partitions [[buffer(0)]], - const device val_t* dev_vals [[buffer(1)]], - const device idx_t* dev_idxs [[buffer(2)]], - const constant int& size_sorted_axis [[buffer(3)]], - const constant int& merge_tiles [[buffer(4)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 tgp_dims [[threads_per_threadgroup]]) { - using sort_kernel = KernelMultiBlockMergeSort< - val_t, - idx_t, - ARG_SORT, - BLOCK_THREADS, - N_PER_THREAD>; - - block_partitions += tid.y * tgp_dims.x; - dev_vals += tid.y * size_sorted_axis; - dev_idxs += tid.y * size_sorted_axis; - - // Find location in merge step - int merge_group = lid.x / merge_tiles; - int merge_lane = lid.x % merge_tiles; - - int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; - int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; - - int A_st = min(size_sorted_axis, sort_st); - int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); - int B_st = A_ed; - int B_ed = min(size_sorted_axis, B_st + sort_sz / 2); - - int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane); - int partition = sort_kernel::merge_partition( - dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at); - - block_partitions[lid.x] = A_st + partition; -} - -template < - typename val_t, - typename idx_t, - bool ARG_SORT, - short BLOCK_THREADS, - short N_PER_THREAD, - typename CompareOp = LessThan> -[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void -mb_block_merge( - const device idx_t* block_partitions [[buffer(0)]], - const device val_t* dev_vals_in [[buffer(1)]], - const device idx_t* dev_idxs_in [[buffer(2)]], - device val_t* dev_vals_out [[buffer(3)]], - device idx_t* dev_idxs_out [[buffer(4)]], - const constant int& size_sorted_axis [[buffer(5)]], - const constant int& merge_tiles [[buffer(6)]], - const constant int& num_tiles [[buffer(7)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - using sort_kernel = KernelMultiBlockMergeSort< - val_t, - idx_t, - ARG_SORT, - BLOCK_THREADS, - N_PER_THREAD, - CompareOp>; - - using block_sort_t = typename sort_kernel::block_merge_sort_t; - - block_partitions += tid.y * (num_tiles + 1); - dev_vals_in += tid.y * size_sorted_axis; - dev_idxs_in += tid.y * size_sorted_axis; - dev_vals_out += tid.y * size_sorted_axis; - dev_idxs_out += tid.y * size_sorted_axis; - - int block_idx = tid.x; - int merge_group = block_idx / merge_tiles; - int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; - int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; - int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st; - - int A_st = block_partitions[block_idx + 0]; - int A_ed = block_partitions[block_idx + 1]; - int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st); - int B_ed = min( - size_sorted_axis, - 2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed); - - if ((block_idx % merge_tiles) == merge_tiles - 1) { - A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); - B_ed = min(size_sorted_axis, sort_st + sort_sz); - } - - int A_sz = A_ed - A_st; - int B_sz = B_ed - B_st; - - // Load from global memory - thread val_t thread_vals[N_PER_THREAD]; - thread idx_t thread_idxs[N_PER_THREAD]; - for (int i = 0; i < 
N_PER_THREAD; i++) { - int idx = BLOCK_THREADS * i + lid.x; - if (idx < (A_sz + B_sz)) { - thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] - : dev_vals_in[B_st + idx - A_sz]; - thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx] - : dev_idxs_in[B_st + idx - A_sz]; - } else { - thread_vals[i] = CompareOp::init; - thread_idxs[i] = 0; - } - } - - // Write to shared memory - threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; - threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; - threadgroup_barrier(mem_flags::mem_threadgroup); - for (int i = 0; i < N_PER_THREAD; i++) { - int idx = BLOCK_THREADS * i + lid.x; - tgp_vals[idx] = thread_vals[i]; - tgp_idxs[idx] = thread_idxs[i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Merge - int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x)); - - int A_st_local = block_sort_t::merge_partition( - tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local); - int A_ed_local = A_sz; - - int B_st_local = sort_md_local - A_st_local; - int B_ed_local = B_sz; - - int A_sz_local = A_ed_local - A_st_local; - int B_sz_local = B_ed_local - B_st_local; - - // Do merge - block_sort_t::merge_step( - tgp_vals + A_st_local, - tgp_vals + A_ed_local + B_st_local, - tgp_idxs + A_st_local, - tgp_idxs + A_ed_local + B_st_local, - A_sz_local, - B_sz_local, - thread_vals, - thread_idxs); - - threadgroup_barrier(mem_flags::mem_threadgroup); - for (int i = 0; i < N_PER_THREAD; ++i) { - int idx = lid.x * N_PER_THREAD; - tgp_vals[idx + i] = thread_vals[i]; - tgp_idxs[idx + i] = thread_idxs[i]; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - // Write output - int base_idx = tid.x * sort_kernel::N_PER_BLOCK; - for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) { - int idx = base_idx + i; - if (idx < size_sorted_axis) { - dev_vals_out[idx] = tgp_vals[i]; - dev_idxs_out[idx] = tgp_idxs[i]; - } - } -} +instantiate_block_sort_long(int64, int64_t) #define instantiate_multi_block_sort( \ vtname, vtype, itname, itype, arg_sort, bn, tn) \ - template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn \ + template [[host_name("sort_mbsort_" #vtname "_" #itname "_bn" #bn \ "_tn" #tn)]] [[kernel]] void \ mb_block_sort( \ const device vtype* inp [[buffer(0)]], \ @@ -763,7 +82,7 @@ mb_block_merge( const device size_t* nc_strides [[buffer(7)]], \ uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]]); \ - template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn \ + template [[host_name("partition_mbsort_" #vtname "_" #itname "_bn" #bn \ "_tn" #tn)]] [[kernel]] void \ mb_block_partition( \ device itype * block_partitions [[buffer(0)]], \ @@ -774,7 +93,7 @@ mb_block_merge( uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]], \ uint3 tgp_dims [[threads_per_threadgroup]]); \ - template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn \ + template [[host_name("merge_mbsort_" #vtname "_" #itname "_bn" #bn \ "_tn" #tn)]] [[kernel]] void \ mb_block_merge( \ const device itype* block_partitions [[buffer(0)]], \ @@ -788,7 +107,6 @@ mb_block_merge( uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]]); -// clang-format off #define instantiate_multi_block_sort_base(vtname, vtype) \ instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8) @@ -800,11 +118,10 @@ instantiate_multi_block_sort_base(int16, int16_t) instantiate_multi_block_sort_base(int32, int32_t) 
instantiate_multi_block_sort_base(float16, half) instantiate_multi_block_sort_base(float32, float) -instantiate_multi_block_sort_base(bfloat16, bfloat16_t) // clang-format on +instantiate_multi_block_sort_base(bfloat16, bfloat16_t) -// clang-format off #define instantiate_multi_block_sort_long(vtname, vtype) \ instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8) instantiate_multi_block_sort_long(uint64, uint64_t) -instantiate_multi_block_sort_long(int64, int64_t) // clang-format on \ No newline at end of file +instantiate_multi_block_sort_long(int64, int64_t) // clang-format on diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index 8241b3006..4b3393fae 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -5,6 +5,7 @@ #include #include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/complex.h" +#include "mlx/backend/metal/kernels/defines.h" typedef half float16_t; diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index b8aede7e9..2dfd7a2ca 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -5,6 +5,13 @@ namespace mlx::core { +MTL::ComputePipelineState* get_arange_kernel( + metal::Device& d, + const std::string& kernel_name, + const array&) { + return d.get_kernel(kernel_name); +} + MTL::ComputePipelineState* get_unary_kernel( metal::Device& d, const std::string& kernel_name, @@ -15,31 +22,84 @@ MTL::ComputePipelineState* get_unary_kernel( MTL::ComputePipelineState* get_binary_kernel( metal::Device& d, const std::string& kernel_name, - const array& in, - const array& out) { + const array&, + const array&) { return d.get_kernel(kernel_name); } MTL::ComputePipelineState* get_binary_two_kernel( metal::Device& d, const std::string& kernel_name, - const array& in, - const array& out) { + const array&, + const array&) { return d.get_kernel(kernel_name); } MTL::ComputePipelineState* get_ternary_kernel( metal::Device& d, const std::string& kernel_name, - const array& out) { + const array&) { return d.get_kernel(kernel_name); } MTL::ComputePipelineState* get_copy_kernel( metal::Device& d, const std::string& kernel_name, - const array& in, - const array& out) { + const array&, + const array&) { + return d.get_kernel(kernel_name); +} + +MTL::ComputePipelineState* get_softmax_kernel( + metal::Device& d, + const std::string& kernel_name, + bool, + const array&) { + return d.get_kernel(kernel_name); +} + +MTL::ComputePipelineState* get_scan_kernel( + metal::Device& d, + const std::string& kernel_name, + bool, + bool, + const array&, + const array&) { + return d.get_kernel(kernel_name); +} + +MTL::ComputePipelineState* get_sort_kernel( + metal::Device& d, + const std::string& kernel_name, + const array&, + const array&, + int, + int) { + return d.get_kernel(kernel_name); +} + +MTL::ComputePipelineState* get_mb_sort_kernel( + metal::Device& d, + const std::string& kernel_name, + const array&, + const array&, + int, + int) { + return d.get_kernel(kernel_name); +} + +MTL::ComputePipelineState* get_reduce_init_kernel( + metal::Device& d, + const std::string& kernel_name, + const array&) { + return d.get_kernel(kernel_name); +} + +MTL::ComputePipelineState* get_reduce_kernel( + metal::Device& d, + const std::string& kernel_name, + const array&, + const array&) { return d.get_kernel(kernel_name); } diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index e1e7996c9..8d73edc48 100644 --- 
a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -6,6 +6,7 @@ #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/utils.h" @@ -27,7 +28,7 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { } auto& s = stream(); auto& d = metal::device(s.device); - auto kernel = d.get_kernel("arange" + type_to_name(out)); + auto kernel = get_arange_kernel(d, "arange" + type_to_name(out), out); size_t nthreads = out.size(); MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); MTL::Size group_dims = MTL::Size( diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index f4a638bc4..73b166658 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -6,6 +6,7 @@ #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/reduce.h" #include "mlx/backend/metal/utils.h" @@ -40,9 +41,12 @@ void all_reduce_dispatch( const Stream& s) { Dtype out_dtype = out.dtype(); bool is_out_64b_int = is_64b_int(out_dtype); - auto kernel = (is_out_64b_int) - ? d.get_kernel("all_reduce_no_atomics_" + op_name + type_to_name(in)) - : d.get_kernel("all_reduce_" + op_name + type_to_name(in)); + std::string kernel_name = "all"; + if (is_out_64b_int) { + kernel_name += "NoAtomics"; + } + kernel_name += "_reduce_" + op_name + type_to_name(in); + auto kernel = get_reduce_kernel(d, kernel_name, in, out); compute_encoder->setComputePipelineState(kernel); @@ -158,18 +162,20 @@ void row_reduce_general_dispatch( bool is_med = non_row_reductions * reduction_size <= 256; is_out_64b_int &= !is_small && !is_med; - std::string small_desc = "_"; - if (is_small) { - small_desc = "_small_"; + std::string small_desc; + if (is_out_64b_int) { + small_desc = "NoAtomics"; + } else if (is_small) { + small_desc = "Small"; } else if (is_med) { - small_desc = "_med_"; + small_desc = "Med"; + } else { + small_desc = ""; } + kname << "rowGeneral" << small_desc << "_reduce_" << op_name + << type_to_name(in); - small_desc = is_out_64b_int ? "_no_atomics_" : small_desc; - - kname << "row_reduce_general" << small_desc << op_name << type_to_name(in); - - auto kernel = d.get_kernel(kname.str()); + auto kernel = get_reduce_kernel(d, kname.str(), in, out); compute_encoder->setComputePipelineState(kernel); // Get dispatch grid dims @@ -335,8 +341,8 @@ void strided_reduce_general_dispatch( // Specialize for small dims if (reduction_size * non_col_reductions < 16) { // Select kernel - auto kernel = - d.get_kernel("col_reduce_small_" + op_name + type_to_name(in)); + auto kernel = get_reduce_kernel( + d, "colSmall_reduce_" + op_name + type_to_name(in), in, out); compute_encoder->setComputePipelineState(kernel); // Select block dims @@ -373,10 +379,12 @@ void strided_reduce_general_dispatch( // Select kernel bool is_out_64b_int = is_64b_int(out_dtype); - auto kernel = (is_out_64b_int) - ? 
d.get_kernel( - "col_reduce_general_no_atomics_" + op_name + type_to_name(in)) - : d.get_kernel("col_reduce_general_" + op_name + type_to_name(in)); + std::string kernel_name = "colGeneral"; + if (is_out_64b_int) { + kernel_name += "NoAtomics"; + } + kernel_name += "_reduce_" + op_name + type_to_name(in); + auto kernel = get_reduce_kernel(d, kernel_name, in, out); compute_encoder->setComputePipelineState(kernel); @@ -490,9 +498,11 @@ void strided_reduce_general_dispatch( } ndim = new_shape.size(); - auto row_reduce_kernel = d.get_kernel( - "row_reduce_general_no_atomics_" + op_name + - type_to_name(intermediate)); + std::string kernel_name = + "rowGeneralNoAtomics_reduce_" + op_name + type_to_name(intermediate); + auto row_reduce_kernel = + get_reduce_kernel(d, kernel_name, intermediate, out); + compute_encoder->setComputePipelineState(row_reduce_kernel); compute_encoder.set_input_array(intermediate, 0); compute_encoder.set_output_array(out, 1); @@ -575,7 +585,8 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { auto& d = metal::device(s.device); auto& compute_encoder = d.get_command_encoder(s.index); { - auto kernel = d.get_kernel("i" + op_name + type_to_name(out)); + auto kernel = get_reduce_init_kernel( + d, "i_reduce_" + op_name + type_to_name(out), out); size_t nthreads = out.size(); MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index 44c8fe5db..e35ecfbe1 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -5,6 +5,7 @@ #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" @@ -28,30 +29,33 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { in = arr_copy; } - std::ostringstream kname; - if (in.strides()[axis_] == 1) { - kname << "contiguous_scan_"; - if (reverse_) { - kname << "reverse_"; - } - kname << ((inclusive_) ? "inclusive_" : "exclusive_"); - switch (reduce_type_) { - case Scan::Sum: - kname << "sum_"; - break; - case Scan::Prod: - kname << "prod_"; - break; - case Scan::Max: - kname << "max_"; - break; - case Scan::Min: - kname << "min_"; - break; - } - kname << type_to_name(in) << "_" << type_to_name(out); + bool contiguous = in.strides()[axis_] == 1; - auto kernel = d.get_kernel(kname.str()); + std::ostringstream kname; + kname << (contiguous ? "contig_" : "strided_"); + kname << "scan_"; + if (reverse_) { + kname << "reverse_"; + } + kname << ((inclusive_) ? "inclusive_" : "exclusive_"); + switch (reduce_type_) { + case Scan::Sum: + kname << "sum_"; + break; + case Scan::Prod: + kname << "prod_"; + break; + case Scan::Max: + kname << "max_"; + break; + case Scan::Min: + kname << "min_"; + break; + } + kname << type_to_name(in) << "_" << type_to_name(out); + auto kernel = get_scan_kernel(d, kname.str(), reverse_, inclusive_, in, out); + + if (contiguous) { auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder->setComputePipelineState(kernel); compute_encoder.set_input_array(in, 0); @@ -79,28 +83,6 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); compute_encoder.dispatchThreads(grid_dims, group_dims); } else { - kname << "strided_scan_"; - if (reverse_) { - kname << "reverse_"; - } - kname << ((inclusive_) ? 
"inclusive_" : "exclusive_"); - switch (reduce_type_) { - case Scan::Sum: - kname << "sum_"; - break; - case Scan::Prod: - kname << "prod_"; - break; - case Scan::Max: - kname << "max_"; - break; - case Scan::Min: - kname << "min_"; - break; - } - kname << type_to_name(in) << "_" << type_to_name(out); - - auto kernel = d.get_kernel(kname.str()); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder->setComputePipelineState(kernel); compute_encoder.set_input_array(in, 0); diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index 41173dd16..22ab3cf4b 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -1,15 +1,18 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" namespace mlx::core { +constexpr int SOFTMAX_LOOPED_LIMIT = 4096; + void Softmax::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); if (!issubdtype(out.dtype(), floating)) { @@ -52,18 +55,17 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { const int simd_size = 32; const int n_reads = SOFTMAX_N_READS; const int looped_limit = SOFTMAX_LOOPED_LIMIT; - std::string op_name = "softmax_"; - if (axis_size > looped_limit) { - op_name += "looped_"; - } + + std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_"; + kernel_name += "softmax_"; if (in.dtype() != float32 && precise_) { - op_name += "precise_"; + kernel_name += "precise_"; } - op_name += type_to_name(out); + kernel_name += type_to_name(out); + + auto kernel = get_softmax_kernel(d, kernel_name, precise_, out); auto& compute_encoder = d.get_command_encoder(s.index); { - auto kernel = d.get_kernel(op_name); - MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index 528f2e951..d594fccce 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -4,6 +4,7 @@ #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" @@ -11,7 +12,6 @@ namespace mlx::core { namespace { -template void single_block_sort( const Stream& s, metal::Device& d, @@ -19,7 +19,8 @@ void single_block_sort( array& out, int axis, int bn, - int tn) { + int tn, + bool argsort) { // Prepare shapes int n_rows = in.size() / in.shape(axis); @@ -46,19 +47,17 @@ void single_block_sort( // Prepare kernel name std::ostringstream kname; - if (ARGSORT) { - kname << "arg_"; + kname << (contiguous_write ? 
"c" : "nc"); + if (argsort) { + kname << "arg"; } - kname << "block_merge_sort_" << type_to_name(in) << "_" << type_to_name(out) - << "_bn" << bn << "_tn" << tn; - if (!contiguous_write) { - kname << "_nc"; - } + kname << "_block_sort_" << type_to_name(in) << "_" << type_to_name(out) + << "_bn" << bn << "_tn" << tn; + auto kernel = get_sort_kernel(d, kname.str(), in, out, bn, tn); // Prepare command encoder auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); compute_encoder->setComputePipelineState(kernel); // Set inputs @@ -81,7 +80,6 @@ void single_block_sort( compute_encoder.dispatchThreadgroups(grid_dims, group_dims); } -template void multi_block_sort( const Stream& s, metal::Device& d, @@ -90,7 +88,8 @@ void multi_block_sort( int axis, int bn, int tn, - int n_blocks) { + int n_blocks, + bool argsort) { // Prepare shapes int n_rows = in.size() / in.shape(axis); @@ -136,10 +135,10 @@ void multi_block_sort( // Do blockwise sort { std::ostringstream kname; - kname << "mb_block_sort_" << type_to_name(dev_vals_0) << "_" + kname << "sort_mbsort_" << type_to_name(dev_vals_0) << "_" << type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn; - - auto kernel = d.get_kernel(kname.str()); + auto kernel = + get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn); compute_encoder->setComputePipelineState(kernel); compute_encoder.set_input_array(in, 0); @@ -175,10 +174,11 @@ void multi_block_sort( // Do partition { std::ostringstream kname; - kname << "mb_block_partition_" << type_to_name(dev_vals_in) << "_" + kname << "partition_mbsort_" << type_to_name(dev_vals_in) << "_" << type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn; - auto kernel = d.get_kernel(kname.str()); + auto kernel = + get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn); compute_encoder->setComputePipelineState(kernel); compute_encoder.set_output_array(block_partitions, 0); @@ -196,10 +196,11 @@ void multi_block_sort( // Do merge { std::ostringstream kname; - kname << "mb_block_merge_" << type_to_name(dev_vals_in) << "_" + kname << "merge_mbsort_" << type_to_name(dev_vals_in) << "_" << type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn; - auto kernel = d.get_kernel(kname.str()); + auto kernel = + get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn); compute_encoder->setComputePipelineState(kernel); compute_encoder.set_input_array(block_partitions, 0); @@ -219,7 +220,7 @@ void multi_block_sort( } // Copy outputs with appropriate strides - array strided_out_arr = ARGSORT ? dev_idxs_out : dev_vals_out; + array strided_out_arr = argsort ? dev_idxs_out : dev_vals_out; if (axis == strided_out_arr.ndim() - 1) { copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s); @@ -252,13 +253,13 @@ void multi_block_sort( [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); } -template void gpu_merge_sort( const Stream& s, metal::Device& d, const array& in, array& out, - int axis_) { + int axis_, + bool argsort) { // Get size info int axis = axis_ < 0 ? 
axis_ + in.ndim() : axis_; int size_sorted_axis = in.shape(axis); @@ -284,9 +285,9 @@ void gpu_merge_sort( int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block; if (n_blocks > 1) { - return multi_block_sort(s, d, in, out, axis, bn, tn, n_blocks); + return multi_block_sort(s, d, in, out, axis, bn, tn, n_blocks, argsort); } else { - return single_block_sort(s, d, in, out, axis, bn, tn); + return single_block_sort(s, d, in, out, axis, bn, tn, argsort); } } @@ -301,7 +302,7 @@ void ArgSort::eval_gpu(const std::vector& inputs, array& out) { auto& d = metal::device(s.device); auto& in = inputs[0]; - gpu_merge_sort(s, d, in, out, axis_); + gpu_merge_sort(s, d, in, out, axis_, true); } void Sort::eval_gpu(const std::vector& inputs, array& out) { @@ -313,7 +314,7 @@ void Sort::eval_gpu(const std::vector& inputs, array& out) { auto& d = metal::device(s.device); auto& in = inputs[0]; - gpu_merge_sort(s, d, in, out, axis_); + gpu_merge_sort(s, d, in, out, axis_, false); } void ArgPartition::eval_gpu(const std::vector& inputs, array& out) { @@ -326,7 +327,7 @@ void ArgPartition::eval_gpu(const std::vector& inputs, array& out) { auto& d = metal::device(s.device); auto& in = inputs[0]; - gpu_merge_sort(s, d, in, out, axis_); + gpu_merge_sort(s, d, in, out, axis_, true); } void Partition::eval_gpu(const std::vector& inputs, array& out) { @@ -339,7 +340,7 @@ void Partition::eval_gpu(const std::vector& inputs, array& out) { auto& d = metal::device(s.device); auto& in = inputs[0]; - gpu_merge_sort(s, d, in, out, axis_); + gpu_merge_sort(s, d, in, out, axis_, false); } } // namespace mlx::core diff --git a/mlx/primitives.h b/mlx/primitives.h index 45b573a02..517ce0b29 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1565,8 +1565,9 @@ class Reduce : public UnaryPrimitive { switch (reduce_type_) { case And: os << "And"; + break; case Or: - os << "And"; + os << "Or"; break; case Sum: os << "Sum"; @@ -1581,7 +1582,6 @@ class Reduce : public UnaryPrimitive { os << "Max"; break; } - os << " Reduce"; } bool is_equivalent(const Primitive& other) const override; @@ -1647,7 +1647,6 @@ class Scan : public UnaryPrimitive { os << "Max"; break; } - os << " Reduce"; } bool is_equivalent(const Primitive& other) const override;
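// --- Illustrative note (not part of the patch) ----------------------------
// The sort.cpp changes above turn ARGSORT from a template parameter into a
// runtime flag that is threaded through gpu_merge_sort into the kernel-name
// lookup (get_sort_kernel / get_mb_sort_kernel). The dispatch decision itself
// is unchanged: a single threadgroup handles up to bn * tn elements along the
// sorted axis, and anything larger falls back to the multi-block path whose
// kernels are now named sort_mbsort_*, partition_mbsort_* and merge_mbsort_*.
// A condensed sketch of that decision (struct and helper names are
// illustrative, not the MLX API):

struct SortPlan {
  bool multi_block; // true -> sort_mbsort_/partition_mbsort_/merge_mbsort_
  int n_blocks;     // number of bn*tn tiles along the sorted axis
};

SortPlan plan_merge_sort(int size_sorted_axis, int bn, int tn) {
  int n_per_block = bn * tn;
  int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block;
  return SortPlan{n_blocks > 1, n_blocks};
}

// plan_merge_sort(1000, 512, 8)  -> {multi_block = false, n_blocks = 1}
// plan_merge_sort(10000, 512, 8) -> {multi_block = true,  n_blocks = 3}
// ---------------------------------------------------------------------------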