From 28be4de7c29c188f58f77d9ca16cfd5ae8c29a82 Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Wed, 28 Aug 2024 16:39:11 -0700
Subject: [PATCH] Fix JIT reductions (#1373)

---
 .circleci/config.yml                   |   6 +
 mlx/backend/metal/CMakeLists.txt       |   1 +
 mlx/backend/metal/jit/reduce.h         | 168 ------------------
 mlx/backend/metal/jit_kernels.cpp      |  38 +++--
 mlx/backend/metal/kernels/reduce.h     |   1 +
 mlx/backend/metal/kernels/reduce.metal | 125 ++++----------
 mlx/backend/metal/reduce.cpp           |  20 +--
 7 files changed, 63 insertions(+), 296 deletions(-)
 delete mode 100644 mlx/backend/metal/jit/reduce.h

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 3e23c9f1f..bd207ae03 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -125,6 +125,12 @@ jobs:
             cd build/
             cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
             make -j
+      - run:
+          name: Run Python tests with JIT
+          command: |
+            source env/bin/activate
+            CMAKE_BUILD_PARALLEL_LEVEL="" CMAKE_ARGS="-DMLX_METAL_JIT=ON" pip install -e . -v
+            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
 
   build_release:
     parameters:
diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt
index d337205da..5ab67aeb7 100644
--- a/mlx/backend/metal/CMakeLists.txt
+++ b/mlx/backend/metal/CMakeLists.txt
@@ -79,6 +79,7 @@ if (MLX_METAL_JIT)
     kernels/reduction/reduce_all.h
     kernels/reduction/reduce_col.h
     kernels/reduction/reduce_row.h
+    kernels/reduction/reduce_init.h
   )
   make_jit_source(
     steel/gemm/gemm
diff --git a/mlx/backend/metal/jit/reduce.h b/mlx/backend/metal/jit/reduce.h
deleted file mode 100644
index ea2f5f4e0..000000000
--- a/mlx/backend/metal/jit/reduce.h
+++ /dev/null
@@ -1,168 +0,0 @@
-// Copyright © 2024 Apple Inc.
-
-constexpr std::string_view reduce_init_kernels = R"(
-[[kernel]] void {0}(
-    device {1}* out [[buffer(0)]],
-    uint tid [[thread_position_in_grid]]) {{
-  out[tid] = {2}<{1}>::init;
-}}
-)";
-
-constexpr std::string_view reduce_kernels = R"(
-template [[host_name("all_{0}")]] [[kernel]] void
-all_reduce<{1}, {2}, {3}<{2}>>(
-    const device {1}* in [[buffer(0)]],
-    device mlx_atomic<{2}>* out [[buffer(1)]],
-    const device size_t& in_size [[buffer(2)]],
-    uint gid [[thread_position_in_grid]],
-    uint lid [[thread_position_in_threadgroup]],
-    uint grid_size [[threads_per_grid]],
-    uint simd_per_group [[simdgroups_per_threadgroup]],
-    uint simd_lane_id [[thread_index_in_simdgroup]],
-    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
-template [[host_name("colGeneral_{0}")]] [[kernel]] void
-col_reduce_general<{1}, {2}, {3}<{2}>>(
-    const device {1}* in [[buffer(0)]],
-    device mlx_atomic<{2}>* out [[buffer(1)]],
-    const constant size_t& reduction_size [[buffer(2)]],
-    const constant size_t& reduction_stride [[buffer(3)]],
-    const constant size_t& out_size [[buffer(4)]],
-    const constant int* shape [[buffer(5)]],
-    const constant size_t* strides [[buffer(6)]],
-    const constant int& ndim [[buffer(7)]],
-    threadgroup {2}* local_data [[threadgroup(0)]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]],
-    uint3 lsize [[threads_per_threadgroup]]);
-template [[host_name("colSmall_{0}")]] [[kernel]] void
-col_reduce_small<{1}, {2}, {3}<{2}>>(
-    const device {1}* in [[buffer(0)]],
-    device {2}* out [[buffer(1)]],
-    const constant size_t& reduction_size [[buffer(2)]],
-    const constant size_t& reduction_stride [[buffer(3)]],
-    const constant size_t& out_size [[buffer(4)]],
-    const constant int* shape [[buffer(5)]],
-    const constant size_t* strides [[buffer(6)]],
-    const constant int& ndim [[buffer(7)]],
-    const constant size_t& non_col_reductions [[buffer(8)]],
-    const constant int* non_col_shapes [[buffer(9)]],
-    const constant size_t* non_col_strides [[buffer(10)]],
-    const constant int& non_col_ndim [[buffer(11)]],
-    uint tid [[thread_position_in_grid]]);
-template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
-row_reduce_general_small<{1}, {2}, {3}<{2}>>(
-    const device {1}* in [[buffer(0)]],
-    device {2}* out [[buffer(1)]],
-    const constant size_t& reduction_size [[buffer(2)]],
-    const constant size_t& out_size [[buffer(3)]],
-    const constant size_t& non_row_reductions [[buffer(4)]],
-    const constant int* shape [[buffer(5)]],
-    const constant size_t* strides [[buffer(6)]],
-    const constant int& ndim [[buffer(7)]],
-    uint lid [[thread_position_in_grid]]);
-template [[host_name("rowGeneralMed_{0}")]] [[kernel]] void
-row_reduce_general_med<{1}, {2}, {3}<{2}>>(
-    const device {1}* in [[buffer(0)]],
-    device {2}* out [[buffer(1)]],
-    const constant size_t& reduction_size [[buffer(2)]],
-    const constant size_t& out_size [[buffer(3)]],
-    const constant size_t& non_row_reductions [[buffer(4)]],
-    const constant int* shape [[buffer(5)]],
-    const constant size_t* strides [[buffer(6)]],
-    const constant int& ndim [[buffer(7)]],
-    uint tid [[threadgroup_position_in_grid]],
-    uint simd_lane_id [[thread_index_in_simdgroup]],
-    uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
-    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
-template [[host_name("rowGeneral_{0}")]] [[kernel]] void
-row_reduce_general<{1}, {2}, {3}<{2}>>(
-    const device {1}* in [[buffer(0)]],
-    device mlx_atomic<{2}>* out [[buffer(1)]],
-    const constant size_t& reduction_size [[buffer(2)]],
-    const constant size_t& out_size [[buffer(3)]],
-    const constant size_t& non_row_reductions [[buffer(4)]],
-    const constant int* shape [[buffer(5)]],
-    const constant size_t* strides [[buffer(6)]],
-    const constant int& ndim [[buffer(7)]],
-    uint3 lid [[thread_position_in_threadgroup]],
-    uint3 lsize [[threads_per_threadgroup]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint simd_lane_id [[thread_index_in_simdgroup]],
-    uint simd_per_group [[simdgroups_per_threadgroup]],
-    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
-)";
-
-constexpr std::string_view reduce_non_atomic_kernels = R"(
-template [[host_name("allNoAtomics_{0}")]] [[kernel]] void
-all_reduce_no_atomics<{1}, {2}, {3}<{2}>>(
-    const device {1}* in [[buffer(0)]],
-    device {2}* out [[buffer(1)]],
-    const device size_t& in_size [[buffer(2)]],
-    uint gid [[thread_position_in_grid]],
-    uint lid [[thread_position_in_threadgroup]],
-    uint grid_size [[threads_per_grid]],
-    uint simd_per_group [[simdgroups_per_threadgroup]],
-    uint simd_lane_id [[thread_index_in_simdgroup]],
-    uint simd_group_id [[simdgroup_index_in_threadgroup]],
-    uint thread_group_id [[threadgroup_position_in_grid]]);
-
-template [[host_name("colGeneralNoAtomics_{0}")]] [[kernel]] void
-    col_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
-        const device {1}* in [[buffer(0)]],
-        device {2}* out [[buffer(1)]],
-        const constant size_t& reduction_size [[buffer(2)]],
-        const constant size_t& reduction_stride [[buffer(3)]],
-        const constant size_t& out_size [[buffer(4)]],
-        const constant int* shape [[buffer(5)]],
-        const constant size_t* strides [[buffer(6)]],
-        const constant int& ndim [[buffer(7)]],
-        threadgroup {2}* local_data [[threadgroup(0)]],
-        uint3 tid [[threadgroup_position_in_grid]],
-        uint3 lid [[thread_position_in_threadgroup]],
-        uint3 gid [[thread_position_in_grid]],
-        uint3 lsize [[threads_per_threadgroup]],
-        uint3 gsize [[threads_per_grid]]);
-template [[host_name("colSmall_{0}")]] [[kernel]] void
-col_reduce_small<{1}, {2}, {3}<{2}>>(
-    const device {1}* in [[buffer(0)]],
-    device {2}* out [[buffer(1)]],
-    const constant size_t& reduction_size [[buffer(2)]],
-    const constant size_t& reduction_stride [[buffer(3)]],
-    const constant size_t& out_size [[buffer(4)]],
-    const constant int* shape [[buffer(5)]],
-    const constant size_t* strides [[buffer(6)]],
-    const constant int& ndim [[buffer(7)]],
-    const constant size_t& non_col_reductions [[buffer(8)]],
-    const constant int* non_col_shapes [[buffer(9)]],
-    const constant size_t* non_col_strides [[buffer(10)]],
-    const constant int& non_col_ndim [[buffer(11)]],
-    uint tid [[thread_position_in_grid]]);
-template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
-row_reduce_general_small<{1}, {2}, {3}<{2}>>(
-    const device {1}* in [[buffer(0)]],
-    device {2}* out [[buffer(1)]],
-    const constant size_t& reduction_size [[buffer(2)]],
-    const constant size_t& out_size [[buffer(3)]],
-    const constant size_t& non_row_reductions [[buffer(4)]],
-    const constant int* shape [[buffer(5)]],
-    const constant size_t* strides [[buffer(6)]],
-    const constant int& ndim [[buffer(7)]],
-    uint lid [[thread_position_in_grid]]);
-template [[host_name("rowGeneralNoAtomics_{0}")]] [[kernel]] void
-row_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
-    const device {1}* in [[buffer(0)]],
-    device {2}* out [[buffer(1)]],
-    const constant size_t& reduction_size [[buffer(2)]],
-    const constant size_t& out_size [[buffer(3)]],
-    const constant size_t& non_row_reductions [[buffer(4)]],
-    const constant int* shape [[buffer(5)]],
-    const constant size_t* strides [[buffer(6)]],
-    const constant int& ndim [[buffer(7)]],
-    uint3 lid [[thread_position_in_threadgroup]],
-    uint3 lsize [[threads_per_threadgroup]],
-    uint3 gsize [[threads_per_grid]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint simd_lane_id [[thread_index_in_simdgroup]],
-    uint simd_per_group [[simdgroups_per_threadgroup]],
-    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
-)";
diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp
index 895ad143d..5a2a4e88e 100644
--- a/mlx/backend/metal/jit_kernels.cpp
+++ b/mlx/backend/metal/jit_kernels.cpp
@@ -6,7 +6,6 @@
 #include "mlx/backend/metal/jit/copy.h"
 #include "mlx/backend/metal/jit/gemv_masked.h"
 #include "mlx/backend/metal/jit/includes.h"
-#include "mlx/backend/metal/jit/reduce.h"
 #include "mlx/backend/metal/jit/scan.h"
 #include "mlx/backend/metal/jit/softmax.h"
 #include "mlx/backend/metal/jit/steel_conv.h"
@@ -323,12 +322,13 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
   auto lib = d.get_library(kernel_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::reduce_utils()
-                  << fmt::format(
-                         reduce_init_kernels,
-                         kernel_name,
-                         get_type_string(out.dtype()),
-                         op_name(out));
+    std::string op_type = op_name(out);
+    op_type[0] = std::toupper(op_name(out)[0]);
+    auto out_type = get_type_string(out.dtype());
+    std::string op = op_type + "<" + out_type + ">";
+    kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
+    kernel_source << get_template_definition(
+        kernel_name, "init_reduce", out_type, op);
     lib = d.get_library(kernel_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib);
@@ -347,14 +347,22 @@ MTL::ComputePipelineState* get_reduce_kernel(
     op_type[0] = std::toupper(op_name[0]);
     bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
     std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
-                  << fmt::format(
-                         non_atomic ? reduce_non_atomic_kernels
-                                    : reduce_kernels,
-                         lib_name,
-                         get_type_string(in.dtype()),
-                         get_type_string(out.dtype()),
-                         op_type);
+    auto in_type = get_type_string(in.dtype());
+    auto out_type = get_type_string(out.dtype());
+    std::vector<std::pair<std::string, std::string>> reduce_kernels = {
+        {"all_reduce", "allReduce"},
+        {"col_reduce_small", "colReduceSmall"},
+        {"col_reduce_looped", "colReduceLooped"},
+        {"row_reduce_small", "rowReduceSmall"},
+        {"row_reduce_looped", "rowReduceLooped"},
+        {"row_reduce_simple", "rowReduceSimple"}};
+    std::string op = op_type + "<" + out_type + ">";
+    kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
+    for (auto [func, name] : reduce_kernels) {
+      kernel_source << get_template_definition(
+          name + "_" + lib_name, func, in_type, out_type, op);
+    }
+
     lib = d.get_library(lib_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib);
diff --git a/mlx/backend/metal/kernels/reduce.h b/mlx/backend/metal/kernels/reduce.h
index 649db0adb..ee5c3d5c0 100644
--- a/mlx/backend/metal/kernels/reduce.h
+++ b/mlx/backend/metal/kernels/reduce.h
@@ -1,4 +1,5 @@
 #pragma once
 #include "mlx/backend/metal/kernels/reduction/reduce_all.h"
 #include "mlx/backend/metal/kernels/reduction/reduce_col.h"
+#include "mlx/backend/metal/kernels/reduction/reduce_init.h"
 #include "mlx/backend/metal/kernels/reduction/reduce_row.h"
diff --git a/mlx/backend/metal/kernels/reduce.metal b/mlx/backend/metal/kernels/reduce.metal
index 3d7e92c52..acfe6b64f 100644
--- a/mlx/backend/metal/kernels/reduce.metal
+++ b/mlx/backend/metal/kernels/reduce.metal
@@ -8,7 +8,6 @@
 #include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/atomic.h"
 #include "mlx/backend/metal/kernels/reduction/ops.h"
-#include "mlx/backend/metal/kernels/reduction/reduce_init.h"
 #include "mlx/backend/metal/kernels/reduce.h"
 
 #define instantiate_reduce_helper_floats(inst_f, name, op) \
@@ -84,9 +83,9 @@
       op)
 
 #define instantiate_init_reduce(name, otype, op) \
-  template [[host_name("i_reduce_" #name)]] [[kernel]] void \
-  init_reduce<otype, op>( \
-      device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);
+  instantiate_kernel("init_reduce_" #name, \
+                     init_reduce, \
+                     otype, op)
 
 #define instantiate_init_reduce_helper(name, tname, type, op) \
   instantiate_init_reduce(name##tname, type, op)
@@ -98,18 +97,9 @@ instantiate_init_reduce(andbool_, bool, And)
 instantiate_init_reduce(orbool_, bool, Or)
 
 #define instantiate_all_reduce(name, itype, otype, op) \
-  template [[host_name("all_reduce_" #name)]] \
-  [[kernel]] void all_reduce<itype, otype, op>( \
-      const device itype* in [[buffer(0)]], \
-      device otype* out [[buffer(1)]], \
-      const constant size_t& in_size [[buffer(2)]], \
-      const constant size_t& row_size [[buffer(3)]], \
-      uint3 gid [[threadgroup_position_in_grid]], \
-      uint3 lid [[thread_position_in_threadgroup]], \
-      uint3 lsize [[threads_per_threadgroup]], \
-      uint simd_per_group [[simdgroups_per_threadgroup]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
+  instantiate_kernel("allReduce_" #name, \
+                     all_reduce, \
+                     itype, otype, op)
 
 #define instantiate_same_all_reduce_helper(name, tname, type, op) \
   instantiate_all_reduce(name##tname, type, type, op)
@@ -124,44 +114,14 @@ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or)
 instantiate_all_reduce(sumbool_, bool, uint32_t, Sum)
 
 #define instantiate_col_reduce_small(name, itype, otype, op, dim) \
-  template [[host_name("colSmall" #dim "_reduce_" #name)]] \
-  [[kernel]] void col_reduce_small<itype, otype, op, dim>( \
-      const device itype* in [[buffer(0)]], \
-      device otype* out [[buffer(1)]], \
-      const constant size_t& reduction_size [[buffer(2)]], \
-      const constant size_t& reduction_stride [[buffer(3)]], \
-      const constant int* shape [[buffer(4)]], \
-      const constant size_t* strides [[buffer(5)]], \
-      const constant int& ndim [[buffer(6)]], \
-      const constant int* reduce_shape [[buffer(7)]], \
-      const constant size_t* reduce_strides [[buffer(8)]], \
-      const constant int& reduce_ndim [[buffer(9)]], \
-      const constant size_t& non_col_reductions [[buffer(10)]], \
-      uint3 gid [[threadgroup_position_in_grid]], \
-      uint3 gsize [[threadgroups_per_grid]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
-      uint3 tid [[thread_position_in_grid]], \
-      uint3 tsize [[threads_per_grid]]);
+  instantiate_kernel("colReduceSmall_" #dim "_reduce_" #name, \
+                     col_reduce_small, \
+                     itype, otype, op, dim)
 
 #define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
-  template [[host_name("colLooped" #dim "_" #bm "_" #bn "_reduce_" #name)]] \
-  [[kernel]] void col_reduce_looped<itype, otype, op, dim, bm, bn>( \
-      const device itype* in [[buffer(0)]], \
-      device otype* out [[buffer(1)]], \
-      const constant size_t& reduction_size [[buffer(2)]], \
-      const constant size_t& reduction_stride [[buffer(3)]], \
-      const constant int* shape [[buffer(4)]], \
-      const constant size_t* strides [[buffer(5)]], \
-      const constant int& ndim [[buffer(6)]], \
-      const constant int* reduce_shape [[buffer(7)]], \
-      const constant size_t* reduce_strides [[buffer(8)]], \
-      const constant int& reduce_ndim [[buffer(9)]], \
-      const constant size_t& non_col_reductions [[buffer(10)]], \
-      uint3 gid [[threadgroup_position_in_grid]], \
-      uint3 gsize [[threadgroups_per_grid]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
+  instantiate_kernel("colReduceLooped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
+                     col_reduce_looped, \
+                     itype, otype, op, dim, bm, bn)
 
 #define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
   instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
@@ -190,44 +150,14 @@ instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And)
 instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
 
 #define instantiate_row_reduce_small(name, itype, otype, op, dim) \
-  template [[host_name("rowSmall" #dim "_reduce_" #name)]] [[kernel]] void \
-  row_reduce_small<itype, otype, op, dim>( \
-      const device itype* in [[buffer(0)]], \
-      device otype* out [[buffer(1)]], \
-      const constant size_t& row_size [[buffer(2)]], \
-      const constant size_t& non_row_reductions [[buffer(3)]], \
-      const constant int* shape [[buffer(4)]], \
-      const constant size_t* strides [[buffer(5)]], \
-      const constant int& ndim [[buffer(6)]], \
-      const constant int* reduce_shape [[buffer(7)]], \
-      const constant size_t* reduce_strides [[buffer(8)]], \
-      const constant int& reduce_ndim [[buffer(9)]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint3 gid [[threadgroup_position_in_grid]], \
-      uint3 gsize [[threadgroups_per_grid]], \
-      uint3 tid [[thread_position_in_grid]], \
-      uint3 tsize [[threads_per_grid]]);
+  instantiate_kernel("rowReduceSmall_" #dim "_reduce_" #name, \
+                     row_reduce_small, \
+                     itype, otype, op, dim)
+
 #define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
-  template \
-  [[host_name("rowLooped" #dim "_reduce_" #name)]] [[kernel]] void \
-  row_reduce_looped<itype, otype, op, dim>( \
-      const device itype* in [[buffer(0)]], \
-      device otype* out [[buffer(1)]], \
-      const constant size_t& row_size [[buffer(2)]], \
-      const constant size_t& non_row_reductions [[buffer(3)]], \
-      const constant int* shape [[buffer(4)]], \
-      const constant size_t* strides [[buffer(5)]], \
-      const constant int& ndim [[buffer(6)]], \
-      const constant int* reduce_shape [[buffer(7)]], \
-      const constant size_t* reduce_strides [[buffer(8)]], \
-      const constant int& reduce_ndim [[buffer(9)]], \
-      uint3 gid [[threadgroup_position_in_grid]], \
-      uint3 gsize [[threadgroups_per_grid]], \
-      uint3 lid [[thread_position_in_threadgroup]], \
-      uint3 lsize [[threads_per_threadgroup]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_per_group [[simdgroups_per_threadgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
+  instantiate_kernel("rowReduceLooped_" #dim "_reduce_" #name, \
+                     row_reduce_looped, \
+                     itype, otype, op, dim)
 
 #define instantiate_row_reduce_general(name, itype, otype, op) \
   instantiate_row_reduce_small(name, itype, otype, op, 0) \
@@ -240,20 +170,9 @@ instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or
   instantiate_row_reduce_looped(name, itype, otype, op, 2) \
   instantiate_row_reduce_looped(name, itype, otype, op, 3) \
   instantiate_row_reduce_looped(name, itype, otype, op, 4) \
-  template \
-  [[host_name("rowSimple_reduce_" #name)]] [[kernel]] void \
-  row_reduce_simple<itype, otype, op>( \
-      const device itype* in [[buffer(0)]], \
-      device otype* out [[buffer(1)]], \
-      const constant size_t& reduction_size [[buffer(2)]], \
-      const constant size_t& out_size [[buffer(3)]], \
-      uint3 gid [[threadgroup_position_in_grid]], \
-      uint3 gsize [[threadgroups_per_grid]], \
-      uint3 lid [[thread_position_in_threadgroup]], \
-      uint3 lsize [[threads_per_threadgroup]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_per_group [[simdgroups_per_threadgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
+  instantiate_kernel("rowReduceSimple_" #name, \
+                     row_reduce_simple, \
+                     itype, otype, op)
 
 #define instantiate_same_row_reduce_helper(name, tname, type, op) \
   instantiate_row_reduce_general(name##tname, type, type, op)
diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp
index 95f6e2eff..72f15d802 100644
--- a/mlx/backend/metal/reduce.cpp
+++ b/mlx/backend/metal/reduce.cpp
@@ -231,8 +231,8 @@ void init_reduce(
     CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s) {
-  auto kernel =
-      get_reduce_init_kernel(d, "i_reduce_" + op_name + type_to_name(out), out);
+  auto kernel = get_reduce_init_kernel(
+      d, "init_reduce_" + op_name + type_to_name(out), out);
   size_t nthreads = out.size();
   MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
@@ -255,7 +255,7 @@ void all_reduce_dispatch(
     std::vector<array>& copies) {
   // Set the kernel
   std::ostringstream kname;
-  kname << "all_reduce_" << op_name << type_to_name(in);
+  kname << "allReduce_" << op_name << type_to_name(in);
   auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
   compute_encoder->setComputePipelineState(kernel);
 
@@ -309,7 +309,7 @@ void all_reduce_dispatch(
 
     // 2nd pass
     std::ostringstream kname_2nd_pass;
-    kname_2nd_pass << "all_reduce_" << op_name << type_to_name(intermediate);
+    kname_2nd_pass << "allReduce_" << op_name << type_to_name(intermediate);
     auto kernel_2nd_pass =
         get_reduce_kernel(d, kname_2nd_pass.str(), op_name, intermediate, out);
     compute_encoder->setComputePipelineState(kernel_2nd_pass);
@@ -335,7 +335,7 @@ void row_reduce_small(
   // Set the kernel
   std::ostringstream kname;
   int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
-  kname << "rowSmall" << n << "_reduce_" << op_name << type_to_name(in);
+  kname << "rowReduceSmall_" << n << "_reduce_" << op_name << type_to_name(in);
   auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
   compute_encoder->setComputePipelineState(kernel);
 
@@ -370,7 +370,7 @@ void row_reduce_simple(
     const Stream& s) {
   // Set the kernel
   std::ostringstream kname;
-  kname << "rowSimple_reduce_" << op_name << type_to_name(in);
+  kname << "rowReduceSimple_" << op_name << type_to_name(in);
   auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
   compute_encoder->setComputePipelineState(kernel);
 
@@ -407,7 +407,7 @@ void row_reduce_looped(
   // Set the kernel
   std::ostringstream kname;
   int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
-  kname << "rowLooped" << n << "_reduce_" << op_name << type_to_name(in);
+  kname << "rowReduceLooped_" << n << "_reduce_" << op_name << type_to_name(in);
   auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
   compute_encoder->setComputePipelineState(kernel);
 
@@ -497,7 +497,7 @@ void strided_reduce_small(
   // Set the kernel
   int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
   std::ostringstream kname;
-  kname << "colSmall" << n << "_reduce_" << op_name << type_to_name(in);
+  kname << "colReduceSmall_" << n << "_reduce_" << op_name << type_to_name(in);
   auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
   compute_encoder->setComputePipelineState(kernel);
 
@@ -535,8 +535,8 @@ void strided_reduce_looped(
   // Set the kernel
   int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
   std::ostringstream kname;
-  kname << "colLooped" << n << "_" << BM << "_" << BN << "_reduce_" << op_name
-        << type_to_name(in);
+  kname << "colReduceLooped_" << n << "_" << BM << "_" << BN << "_reduce_"
+        << op_name << type_to_name(in);
   auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
   compute_encoder->setComputePipelineState(kernel);
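
For reference, below is a minimal, self-contained C++17 sketch (not taken from the patch) of the host-name scheme the JIT path above relies on: each reduce variant is instantiated under a "<camelCaseName>_<lib_name>" name, which is the same string the dispatch code in reduce.cpp now assembles (for example "allReduce_" << op_name << type_to_name(in)). The concrete values used here ("float", "Sum<float>", "sumfloat32") are illustrative assumptions rather than output of the real MLX helpers; the program only prints the mapping that get_template_definition() is asked to realize for each (function, host name) pair.

// naming_sketch.cpp -- hypothetical illustration, not MLX code.
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Example stand-ins for the strings the JIT path derives from the arrays.
  const std::string in_type = "float";
  const std::string out_type = "float";
  const std::string op = "Sum<" + out_type + ">";
  const std::string lib_name = "sumfloat32"; // assumed op name + dtype suffix

  // Same (kernel function, host-name prefix) pairs as in jit_kernels.cpp.
  const std::vector<std::pair<std::string, std::string>> reduce_kernels = {
      {"all_reduce", "allReduce"},
      {"col_reduce_small", "colReduceSmall"},
      {"col_reduce_looped", "colReduceLooped"},
      {"row_reduce_small", "rowReduceSmall"},
      {"row_reduce_looped", "rowReduceLooped"},
      {"row_reduce_simple", "rowReduceSimple"}};

  for (const auto& [func, name] : reduce_kernels) {
    // In the patch, (name + "_" + lib_name, func, in_type, out_type, op) is
    // handed to get_template_definition(); here we just print the mapping
    // from host name to the template instantiation it corresponds to.
    std::cout << name << "_" << lib_name << "  ->  " << func << "<" << in_type
              << ", " << out_type << ", " << op << ">\n";
  }
  return 0;
}

Compiling and running this with any C++17 compiler prints one line per variant (e.g. "allReduce_sumfloat32  ->  all_reduce<float, float, Sum<float>>"), which is the naming contract the runtime kname strings in reduce.cpp must match.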