diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp
index 419f48789..302ca2f99 100644
--- a/mlx/backend/cuda/compiled.cpp
+++ b/mlx/backend/cuda/compiled.cpp
@@ -332,9 +332,9 @@ void Compiled::eval_gpu(
     encoder.set_output_array(out);
   }
 
-  auto kernel = mod.get_kernel(kernel_name);
+  auto [kernel, max_block_dims] = mod.get_kernel_and_dims(kernel_name);
   auto [num_blocks, block_dims] =
-      get_launch_args(outputs[0], large, work_per_thread);
+      get_launch_args(outputs[0], large, work_per_thread, max_block_dims);
   encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
 }
 
diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp
index 4f98d2ebf..9a55bf902 100644
--- a/mlx/backend/cuda/jit_module.cpp
+++ b/mlx/backend/cuda/jit_module.cpp
@@ -297,7 +297,8 @@ void load_module(
     const std::string& ptx,
     const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
     CUmodule& module_,
-    std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
+    std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>>&
+        kernels) {
   // Load module.
   char jit_log[4089] = {};
   CUjit_option options[] = {
@@ -314,7 +315,7 @@ void load_module(
   for (const auto& [name, mangled] : ptx_kernels) {
     CUfunction kernel;
     CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
-    kernels[name] = std::make_pair(kernel, false);
+    kernels[name] = std::make_tuple(kernel, false, 0);
   }
 }
 
@@ -358,7 +359,7 @@ JitModule::~JitModule() {
   CHECK_CUDA_ERROR(cuModuleUnload(module_));
 }
 
-CUfunction JitModule::get_kernel(
+std::pair<CUfunction, uint> JitModule::get_kernel_and_dims(
     const std::string& kernel_name,
     std::function<void(CUfunction)> configure_kernel) {
   auto it = kernels_.find(kernel_name);
@@ -369,14 +370,22 @@ CUfunction JitModule::get_kernel(
 
   // If it is the first time we run this kernel then configure it. Do it only
   // once!
-  if (!it->second.second) {
+  auto kernel = std::get<0>(it->second);
+  if (!std::get<1>(it->second)) {
     if (configure_kernel) {
-      configure_kernel(it->second.first);
+      configure_kernel(kernel);
     }
-    it->second.second = true;
+    std::get<1>(it->second) = true;
+    std::get<2>(it->second) = max_occupancy_block_dim(kernel);
   }
-  return it->second.first;
+  return {kernel, std::get<2>(it->second)};
+}
+
+CUfunction JitModule::get_kernel(
+    const std::string& kernel_name,
+    std::function<void(CUfunction)> configure_kernel) {
+  return get_kernel_and_dims(kernel_name, std::move(configure_kernel)).first;
 }
 
 std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h
index cc569690a..e2fd0c8b8 100644
--- a/mlx/backend/cuda/jit_module.h
+++ b/mlx/backend/cuda/jit_module.h
@@ -99,10 +99,13 @@ class JitModule {
   CUfunction get_kernel(
       const std::string& kernel_name,
      std::function<void(CUfunction)> configure_kernel = nullptr);
+  std::pair<CUfunction, uint> get_kernel_and_dims(
+      const std::string& kernel_name,
+      std::function<void(CUfunction)> configure_kernel = nullptr);
 
  private:
   CUmodule module_{nullptr};
-  std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
+  std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>> kernels_;
 };
 
 std::unordered_map<std::string, JitModule>& get_jit_module_cache();
diff --git a/mlx/backend/cuda/kernel_utils.cu b/mlx/backend/cuda/kernel_utils.cu
index 9ac9a82da..4304318a6 100644
--- a/mlx/backend/cuda/kernel_utils.cu
+++ b/mlx/backend/cuda/kernel_utils.cu
@@ -35,12 +35,10 @@ std::tuple<dim3, uint> get_launch_args(
     const Shape& shape,
     const Strides& strides,
     bool large,
-    int work_per_thread) {
+    int work_per_thread /* = 1 */,
+    uint max_block_dim /* = 1024 */) {
   size_t nthreads = cuda::ceil_div(size, work_per_thread);
-  uint block_dim = 1024;
-  if (block_dim > nthreads) {
-    block_dim = nthreads;
-  }
+  uint block_dim = max_block_dim < nthreads ? max_block_dim : nthreads;
   dim3 num_blocks;
   if (large) {
     num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh
index 7a37361ea..9fca29116 100644
--- a/mlx/backend/cuda/kernel_utils.cuh
+++ b/mlx/backend/cuda/kernel_utils.cuh
@@ -120,19 +120,28 @@ dim3 get_2d_grid_dims(
     size_t divisor);
 std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
 
-// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
-// assuming each thread handles |work_per_thread| elements of |arr|.
+// Get the num_blocks and block_dims assuming each thread handles
+// |work_per_thread| elements of |arr|.
 std::tuple<dim3, uint> get_launch_args(
     size_t size,
     const Shape& shape,
     const Strides& strides,
     bool large,
-    int work_per_thread = 1);
+    int work_per_thread = 1,
+    uint max_block_dim = 1024);
 
-inline std::tuple<dim3, uint>
-get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
+inline std::tuple<dim3, uint> get_launch_args(
+    const array& arr,
+    bool large,
+    int work_per_thread = 1,
+    uint max_block_dim = 1024) {
   return get_launch_args(
-      arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
+      arr.size(),
+      arr.shape(),
+      arr.strides(),
+      large,
+      work_per_thread,
+      max_block_dim);
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h
index e811d5e6c..81b19e346 100644
--- a/mlx/backend/cuda/utils.h
+++ b/mlx/backend/cuda/utils.h
@@ -12,6 +12,7 @@ namespace mlx::core {
 
 namespace cu {
 class Device;
+
 }
 
 struct Dtype;
@@ -86,4 +87,17 @@ class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
   explicit CudaStream(cu::Device& device);
 };
 
+template <typename T>
+inline uint max_occupancy_block_dim(T kernel) {
+  int _, block_dim;
+  if constexpr (std::is_same_v<T, CUfunction>) {
+    CHECK_CUDA_ERROR(
+        cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
+  } else {
+    CHECK_CUDA_ERROR(
+        cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
+  }
+  return block_dim;
+}
+
 } // namespace mlx::core
diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py
index 31fd38588..5528b094e 100644
--- a/python/tests/test_compile.py
+++ b/python/tests/test_compile.py
@@ -828,6 +828,19 @@ class TestCompile(mlx_tests.MLXTestCase):
         out = fun(arrs)
         self.assertTrue(mx.allclose(out, mx.array([64.0, 128.0])))
 
+        inputs = [mx.arange(16384).astype(mx.float16) for _ in range(8)]
+
+        def fun(inputs):
+            a = inputs[0] + inputs[1]
+            b = inputs[2] + inputs[3]
+            c = inputs[4] + inputs[5]
+            d = inputs[6] + inputs[7]
+            return a * b * c * d
+
+        out = mx.compile(fun)(inputs)
+        expected = fun(inputs)
+        self.assertTrue(mx.allclose(out, expected))
+
     def test_compile_many_outputs(self):
         @mx.compile