diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 6f3d6e2e3..db20f938c 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -102,7 +102,7 @@ inline void build_kernel( // a third grid dimension os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n"; } else if (contiguous) { - os += " int index = N_ * pos.x;\n"; + os += " uint index = N_ * pos.x;\n"; } else if (work_per_thread > 1) { os += fmt::format( " int xshape = output_shape[{0}];\n", @@ -115,6 +115,9 @@ inline void build_kernel( " {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n", idx_type); } + if (work_per_thread > 1 && contiguous) { + os += " for (int i = 0; i < N_ && index < size; ++i) {\n"; + } // Read constant / contiguous inputs in tmps std::vector nc_inputs; @@ -198,13 +201,9 @@ inline void build_kernel( } // Open per-thread loop - if (work_per_thread > 1) { - if (contiguous) { - os += " for (int i = 0; i < N_ && index < size; ++i) {\n"; - } else { - os += - " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n"; - } + if (work_per_thread > 1 && !contiguous) { + os += + " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n"; } // Read non-contiguous inputs into tmps @@ -296,7 +295,7 @@ void Compiled::eval_gpu( /* ndim = */ 0, /* dynamic_dims = */ false, /* use_big_index = */ false, - /* work_per_thread */ work_per_thread); + /* work_per_thread = */ work_per_thread); build_kernel( kernel, kernel_lib_ + "_contiguous_large", @@ -308,7 +307,7 @@ void Compiled::eval_gpu( /* ndim = */ 0, /* dynamic_dims = */ false, /* use_big_index = */ true, - /* work_per_thread */ work_per_thread); + /* work_per_thread = */ work_per_thread); for (int i = 1; i < 8; i++) { build_kernel( kernel, @@ -497,7 +496,7 @@ void Compiled::eval_gpu( // Launch the kernel if (contiguous) { - int work_per_thread = get_work_per_thread(outputs_[0].dtype()); + int work_per_thread = get_work_per_thread(outputs[0].dtype()); size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread); MTL::Size group_dims( std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index d59968453..e81ae1562 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -106,7 +106,6 @@ void ternary_op_gpu_inplace( compute_encoder.dispatch_threads(grid_dims, group_dims); } else { // Launch a 1D or 2D grid of threads - auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads;