diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 154273233..6f3d6e2e3 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -64,6 +64,7 @@ inline void build_kernel( cnt++); } + std::string idx_type = use_big_index ? "int64_t" : "uint"; if (add_indices) { os += fmt::format( " constant const int64_t* in_strides [[buffer({0})]],\n", cnt++); @@ -83,6 +84,9 @@ inline void build_kernel( " constant const int64_t* output_strides [[buffer({0})]],\n", cnt++); os += fmt::format( " constant const int* output_shape [[buffer({0})]],\n", cnt++); + } else { + os += fmt::format( + " constant const {0}& size [[buffer({1})]],\n", idx_type, cnt++); } if (dynamic_dims) { os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++); @@ -92,13 +96,14 @@ inline void build_kernel( os += " uint3 pos [[thread_position_in_grid]],\n"; os += " uint3 grid [[threads_per_grid]]) {\n"; - std::string idx_type = use_big_index ? "int64_t" : "uint"; + os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread); if (contiguous && use_big_index) { // This is only used for contiguous kernels which don't have // a third grid dimension - os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n"; + os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n"; + } else if (contiguous) { + os += " int index = N_ * pos.x;\n"; } else if (work_per_thread > 1) { - os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread); os += fmt::format( " int xshape = output_shape[{0}];\n", dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)); @@ -194,8 +199,12 @@ inline void build_kernel( // Open per-thread loop if (work_per_thread > 1) { - os += - " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n"; + if (contiguous) { + os += " for (int i = 0; i < N_ && index < size; ++i) {\n"; + } else { + os += + " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n"; + } } // Read non-contiguous inputs into tmps @@ -272,6 +281,7 @@ void Compiled::eval_gpu( auto& s = stream(); auto& d = metal::device(s.device); auto lib = d.get_library(kernel_lib_, [&]() { + int work_per_thread = get_work_per_thread(outputs_[0].dtype()); std::string kernel = metal::utils(); concatenate( kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops()); @@ -284,7 +294,9 @@ void Compiled::eval_gpu( constant_ids_, /* contiguous = */ true, /* ndim = */ 0, - /* dynamic_dims = */ false); + /* dynamic_dims = */ false, + /* use_big_index = */ false, + /* work_per_thread */ work_per_thread); build_kernel( kernel, kernel_lib_ + "_contiguous_large", @@ -295,7 +307,8 @@ void Compiled::eval_gpu( /* contiguous = */ true, /* ndim = */ 0, /* dynamic_dims = */ false, - /* use_big_index = */ true); + /* use_big_index = */ true, + /* work_per_thread */ work_per_thread); for (int i = 1; i < 8; i++) { build_kernel( kernel, @@ -468,6 +481,13 @@ void Compiled::eval_gpu( if (!contiguous) { compute_encoder.set_vector_bytes(strides[0], cnt++); compute_encoder.set_vector_bytes(shape, cnt++); + } else { + auto size = outputs[0].data_size(); + if (large) { + compute_encoder.set_bytes(size, cnt++); + } else { + compute_encoder.set_bytes(size, cnt++); + } } // Put the number of dims in if it is dynamic @@ -477,12 +497,13 @@ void Compiled::eval_gpu( // Launch the kernel if (contiguous) { - size_t nthreads = outputs[0].data_size(); + int work_per_thread = get_work_per_thread(outputs_[0].dtype()); + size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread); MTL::Size group_dims( std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); - MTL::Size grid_dims = large - ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides()) + ? get_2d_grid_dims( + outputs[0].shape(), outputs[0].strides(), work_per_thread) : MTL::Size(nthreads, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index 0734c2326..c30d186b8 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -15,7 +15,8 @@ typedef half float16_t; -// Work per thread values for different types +// Work per thread values for different types. The values here are expected to +// match get_work_per_thread in mlx/backend/metal/utils.h template struct WorkPerThread { static_assert(sizeof(U) <= 8, "Type too large");