Fix JIT reductions (#1373)

This commit is contained in:
Alex Barron 2024-08-28 16:39:11 -07:00 committed by GitHub
parent a6c3b38fba
commit 28be4de7c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 63 additions and 296 deletions

View File

@ -125,6 +125,12 @@ jobs:
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
make -j
- run:
name: Run Python tests with JIT
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL="" CMAKE_ARGS="-DMLX_METAL_JIT=ON" pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
build_release:
parameters:

View File

@ -79,6 +79,7 @@ if (MLX_METAL_JIT)
kernels/reduction/reduce_all.h
kernels/reduction/reduce_col.h
kernels/reduction/reduce_row.h
kernels/reduction/reduce_init.h
)
make_jit_source(
steel/gemm/gemm

View File

@ -1,168 +0,0 @@
// Copyright © 2024 Apple Inc.
// Format-string template for the JIT-compiled reduction "init" kernel.
// It is expanded with fmt::format, so doubled braces stand for literal braces
// in the generated Metal source. Placeholders, as supplied by the caller:
//   {0} - kernel name
//   {1} - type string of the output dtype
//   {2} - reduction op name; the kernel writes {2}<{1}>::init to every
//         element of `out`, indexed by the thread position in the grid.
constexpr std::string_view reduce_init_kernels = R"(
[[kernel]] void {0}(
device {1}* out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {{
out[tid] = {2}<{1}>::init;
}}
)";
// Format-string template holding explicit instantiations of the reduction
// kernels for the default path (the caller picks this template when the
// output dtype is neither int64 nor uint64). Expanded with fmt::format;
// placeholders as supplied by the caller:
//   {0} - library name, appended to every host_name
//   {1} - input element type string
//   {2} - output element type string
//   {3} - reduction op type name, instantiated as {3}<{2}>
// The all / colGeneral / rowGeneral variants write through
// `device mlx_atomic<{2}>*`, while the colSmall / rowGeneralSmall /
// rowGeneralMed variants write to a plain `device {2}*` output.
constexpr std::string_view reduce_kernels = R"(
template [[host_name("all_{0}")]] [[kernel]] void
all_reduce<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("colGeneral_{0}")]] [[kernel]] void
col_reduce_general<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup {2}* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
const constant size_t& non_col_reductions [[buffer(8)]],
const constant int* non_col_shapes [[buffer(9)]],
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralMed_{0}")]] [[kernel]] void
row_reduce_general_med<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("rowGeneral_{0}")]] [[kernel]] void
row_reduce_general<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";
// Format-string template holding explicit instantiations of the reduction
// kernels that do not use atomic outputs (the caller picks this template
// when the output dtype is int64 or uint64). Expanded with fmt::format;
// placeholders as supplied by the caller:
//   {0} - library name, appended to every host_name
//   {1} - input element type string
//   {2} - output element type string
//   {3} - reduction op type name, instantiated as {3}<{2}>
// Every kernel here writes to a plain `device {2}*` output. The
// all / colGeneral / rowGeneral entries bind the *_no_atomics kernel
// functions under "...NoAtomics_" host names; colSmall and rowGeneralSmall
// use the same functions and host names as the atomic template.
// NOTE(review): this template contains no rowGeneralMed instantiation.
constexpr std::string_view reduce_non_atomic_kernels = R"(
template [[host_name("allNoAtomics_{0}")]] [[kernel]] void
all_reduce_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]);
template [[host_name("colGeneralNoAtomics_{0}")]] [[kernel]] void
col_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup {2}* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
const constant size_t& non_col_reductions [[buffer(8)]],
const constant int* non_col_shapes [[buffer(9)]],
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralNoAtomics_{0}")]] [[kernel]] void
row_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

View File

@ -6,7 +6,6 @@
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/reduce.h"
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/steel_conv.h"
@ -323,12 +322,13 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
auto lib = d.get_library(kernel_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
<< fmt::format(
reduce_init_kernels,
kernel_name,
get_type_string(out.dtype()),
op_name(out));
std::string op_type = op_name(out);
op_type[0] = std::toupper(op_name(out)[0]);
auto out_type = get_type_string(out.dtype());
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
kernel_source << get_template_definition(
kernel_name, "init_reduce", out_type, op);
lib = d.get_library(kernel_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@ -347,14 +347,22 @@ MTL::ComputePipelineState* get_reduce_kernel(
op_type[0] = std::toupper(op_name[0]);
bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
<< fmt::format(
non_atomic ? reduce_non_atomic_kernels
: reduce_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_type);
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
std::vector<std::pair<std::string, std::string>> reduce_kernels = {
{"all_reduce", "allReduce"},
{"col_reduce_small", "colReduceSmall"},
{"col_reduce_looped", "colReduceLooped"},
{"row_reduce_small", "rowReduceSmall"},
{"row_reduce_looped", "rowReduceLooped"},
{"row_reduce_simple", "rowReduceSimple"}};
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
for (auto [func, name] : reduce_kernels) {
kernel_source << get_template_definition(
name + "_" + lib_name, func, in_type, out_type, op);
}
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);

View File

@ -1,4 +1,5 @@
#pragma once
#include "mlx/backend/metal/kernels/reduction/reduce_all.h"
#include "mlx/backend/metal/kernels/reduction/reduce_col.h"
#include "mlx/backend/metal/kernels/reduction/reduce_init.h"
#include "mlx/backend/metal/kernels/reduction/reduce_row.h"

View File

@ -8,7 +8,6 @@
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_init.h"
#include "mlx/backend/metal/kernels/reduce.h"
#define instantiate_reduce_helper_floats(inst_f, name, op) \
@ -84,9 +83,9 @@
op)
#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i_reduce_" #name)]] [[kernel]] void \
init_reduce<otype, op>( \
device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);
instantiate_kernel("init_reduce_" #name, \
init_reduce, \
otype, op)
#define instantiate_init_reduce_helper(name, tname, type, op) \
instantiate_init_reduce(name##tname, type, op<type>)
@ -98,18 +97,9 @@ instantiate_init_reduce(andbool_, bool, And<bool>)
instantiate_init_reduce(orbool_, bool, Or<bool>)
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] \
[[kernel]] void all_reduce<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& in_size [[buffer(2)]], \
const constant size_t& row_size [[buffer(3)]], \
uint3 gid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
instantiate_kernel("allReduce_" #name, \
all_reduce, \
itype, otype, op)
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
instantiate_all_reduce(name##tname, type, type, op<type>)
@ -124,44 +114,14 @@ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
template [[host_name("colSmall" #dim "_reduce_" #name)]] \
[[kernel]] void col_reduce_small<itype, otype, op, dim>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
const constant int* reduce_shape [[buffer(7)]], \
const constant size_t* reduce_strides [[buffer(8)]], \
const constant int& reduce_ndim [[buffer(9)]], \
const constant size_t& non_col_reductions [[buffer(10)]], \
uint3 gid [[threadgroup_position_in_grid]], \
uint3 gsize [[threadgroups_per_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[thread_position_in_grid]], \
uint3 tsize [[threads_per_grid]]);
instantiate_kernel("colReduceSmall_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, dim)
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
template [[host_name("colLooped" #dim "_" #bm "_" #bn "_reduce_" #name)]] \
[[kernel]] void col_reduce_looped<itype, otype, op, dim, bm, bn>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
const constant int* reduce_shape [[buffer(7)]], \
const constant size_t* reduce_strides [[buffer(8)]], \
const constant int& reduce_ndim [[buffer(9)]], \
const constant size_t& non_col_reductions [[buffer(10)]], \
uint3 gid [[threadgroup_position_in_grid]], \
uint3 gsize [[threadgroups_per_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
instantiate_kernel("colReduceLooped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, dim, bm, bn)
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
@ -190,44 +150,14 @@ instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<boo
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
template [[host_name("rowSmall" #dim "_reduce_" #name)]] [[kernel]] void \
row_reduce_small<itype, otype, op, dim>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& row_size [[buffer(2)]], \
const constant size_t& non_row_reductions [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
const constant int* reduce_shape [[buffer(7)]], \
const constant size_t* reduce_strides [[buffer(8)]], \
const constant int& reduce_ndim [[buffer(9)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint3 gid [[threadgroup_position_in_grid]], \
uint3 gsize [[threadgroups_per_grid]], \
uint3 tid [[thread_position_in_grid]], \
uint3 tsize [[threads_per_grid]]);
instantiate_kernel("rowReduceSmall_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, dim)
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
template \
[[host_name("rowLooped" #dim "_reduce_" #name)]] [[kernel]] void \
row_reduce_looped<itype, otype, op, dim>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& row_size [[buffer(2)]], \
const constant size_t& non_row_reductions [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
const constant int* reduce_shape [[buffer(7)]], \
const constant size_t* reduce_strides [[buffer(8)]], \
const constant int& reduce_ndim [[buffer(9)]], \
uint3 gid [[threadgroup_position_in_grid]], \
uint3 gsize [[threadgroups_per_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
instantiate_kernel("rowReduceLooped_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, dim)
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op, 0) \
@ -240,20 +170,9 @@ instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>
instantiate_row_reduce_looped(name, itype, otype, op, 2) \
instantiate_row_reduce_looped(name, itype, otype, op, 3) \
instantiate_row_reduce_looped(name, itype, otype, op, 4) \
template \
[[host_name("rowSimple_reduce_" #name)]] [[kernel]] void \
row_reduce_simple<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
uint3 gid [[threadgroup_position_in_grid]], \
uint3 gsize [[threadgroups_per_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
instantiate_kernel("rowReduceSimple_" #name, \
row_reduce_simple, \
itype, otype, op)
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
instantiate_row_reduce_general(name##tname, type, type, op<type>)

View File

@ -231,8 +231,8 @@ void init_reduce(
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
auto kernel =
get_reduce_init_kernel(d, "i_reduce_" + op_name + type_to_name(out), out);
auto kernel = get_reduce_init_kernel(
d, "init_reduce_" + op_name + type_to_name(out), out);
size_t nthreads = out.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
@ -255,7 +255,7 @@ void all_reduce_dispatch(
std::vector<array>& copies) {
// Set the kernel
std::ostringstream kname;
kname << "all_reduce_" << op_name << type_to_name(in);
kname << "allReduce_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
compute_encoder->setComputePipelineState(kernel);
@ -309,7 +309,7 @@ void all_reduce_dispatch(
// 2nd pass
std::ostringstream kname_2nd_pass;
kname_2nd_pass << "all_reduce_" << op_name << type_to_name(intermediate);
kname_2nd_pass << "allReduce_" << op_name << type_to_name(intermediate);
auto kernel_2nd_pass =
get_reduce_kernel(d, kname_2nd_pass.str(), op_name, intermediate, out);
compute_encoder->setComputePipelineState(kernel_2nd_pass);
@ -335,7 +335,7 @@ void row_reduce_small(
// Set the kernel
std::ostringstream kname;
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
kname << "rowSmall" << n << "_reduce_" << op_name << type_to_name(in);
kname << "rowReduceSmall_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
compute_encoder->setComputePipelineState(kernel);
@ -370,7 +370,7 @@ void row_reduce_simple(
const Stream& s) {
// Set the kernel
std::ostringstream kname;
kname << "rowSimple_reduce_" << op_name << type_to_name(in);
kname << "rowReduceSimple_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
compute_encoder->setComputePipelineState(kernel);
@ -407,7 +407,7 @@ void row_reduce_looped(
// Set the kernel
std::ostringstream kname;
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
kname << "rowLooped" << n << "_reduce_" << op_name << type_to_name(in);
kname << "rowReduceLooped_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
compute_encoder->setComputePipelineState(kernel);
@ -497,7 +497,7 @@ void strided_reduce_small(
// Set the kernel
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
kname << "colSmall" << n << "_reduce_" << op_name << type_to_name(in);
kname << "colReduceSmall_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
compute_encoder->setComputePipelineState(kernel);
@ -535,8 +535,8 @@ void strided_reduce_looped(
// Set the kernel
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
kname << "colLooped" << n << "_" << BM << "_" << BN << "_reduce_" << op_name
<< type_to_name(in);
kname << "colReduceLooped_" << n << "_" << BM << "_" << BN << "_reduce_"
<< op_name << type_to_name(in);
auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out);
compute_encoder->setComputePipelineState(kernel);