From ec0d5db67b44916ee7706ef2ac624d642510bdac Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 2 Jul 2025 15:59:13 -0700 Subject: [PATCH] [CUDA] Switch to CUDA graphs (#2317) * cuda graph prototype fix signal bug + start to add dependencies capture more capture more ops remaining ops fix reduce and rope deps add concurrent context try update, but not working cosistent topology order use node api use node api directly to reduce overhead fix bug use kernels in unary cache graph format fix synchronization format * comment --- mlx/backend/common/matmul.h | 23 +- mlx/backend/cuda/arg_reduce.cu | 47 ++- mlx/backend/cuda/binary.cu | 154 +++++----- mlx/backend/cuda/binary_two.cu | 167 +++++----- mlx/backend/cuda/compiled.cpp | 21 +- mlx/backend/cuda/copy/copy_contiguous.cu | 37 +-- mlx/backend/cuda/copy/copy_general.cu | 82 ++--- mlx/backend/cuda/copy/copy_general_dynamic.cu | 83 ++--- mlx/backend/cuda/copy/copy_general_input.cu | 72 ++--- mlx/backend/cuda/device.cpp | 289 ++++++++++++++---- mlx/backend/cuda/device.h | 170 ++++++----- mlx/backend/cuda/eval.cpp | 28 +- mlx/backend/cuda/event.cu | 20 +- mlx/backend/cuda/indexing.cpp | 154 +++++----- mlx/backend/cuda/jit_module.cpp | 65 +--- mlx/backend/cuda/jit_module.h | 76 +++-- mlx/backend/cuda/kernel_utils.cuh | 9 +- mlx/backend/cuda/layer_norm.cu | 106 ++++--- mlx/backend/cuda/logsumexp.cu | 22 +- mlx/backend/cuda/matmul.cpp | 73 ++--- mlx/backend/cuda/primitives.cu | 30 +- mlx/backend/cuda/random.cu | 61 ++-- mlx/backend/cuda/reduce/all_reduce.cu | 51 ++-- mlx/backend/cuda/reduce/col_reduce.cu | 36 ++- mlx/backend/cuda/reduce/init_reduce.cu | 22 +- mlx/backend/cuda/reduce/row_reduce.cu | 108 ++++--- mlx/backend/cuda/rms_norm.cu | 93 +++--- mlx/backend/cuda/rope.cu | 151 ++++----- mlx/backend/cuda/softmax.cu | 28 +- mlx/backend/cuda/sort.cu | 157 +++++----- mlx/backend/cuda/ternary.cu | 127 ++++---- mlx/backend/cuda/unary.cu | 97 ++++-- mlx/backend/cuda/utils.cpp | 8 + mlx/backend/cuda/utils.h | 2 + mlx/linalg.cpp | 2 +- python/tests/test_load.py | 2 + 36 files changed, 1461 insertions(+), 1212 deletions(-) diff --git a/mlx/backend/common/matmul.h b/mlx/backend/common/matmul.h index 2e0261a30..2faf256d1 100644 --- a/mlx/backend/common/matmul.h +++ b/mlx/backend/common/matmul.h @@ -12,16 +12,11 @@ namespace mlx::core { inline std::tuple collapse_batches( const array& a, const array& b) { - // Get and check the shape for the batched dims - Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; - Shape B_bshape{b.shape().begin(), b.shape().end() - 2}; - if (A_bshape != B_bshape) { - std::ostringstream msg; - msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A " - << a.shape() << ", B " << b.shape() << "."; - throw std::runtime_error(msg.str()); + if (a.ndim() == 2) { + return {{1}, {0}, {0}}; } + Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; @@ -42,17 +37,11 @@ inline std::tuple collapse_batches( inline std::tuple collapse_batches(const array& a, const array& b, const array& c) { - // Get and check the shape for the batched dims - Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; - Shape B_bshape{b.shape().begin(), b.shape().end() - 2}; - Shape C_bshape{c.shape().begin(), c.shape().end() - 2}; - if (A_bshape != B_bshape || A_bshape != C_bshape) { - std::ostringstream msg; - msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A " - << a.shape() << ", B " << b.shape() << ", B " << c.shape() << "."; - throw std::runtime_error(msg.str()); + if (a.ndim() == 2) { + return {{1}, {0}, {0}, {0}}; } + Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; Strides C_bstride{c.strides().begin(), c.strides().end() - 2}; diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu index 90f8561c1..ad942a406 100644 --- a/mlx/backend/cuda/arg_reduce.cu +++ b/mlx/backend/cuda/arg_reduce.cu @@ -151,30 +151,29 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) { - using T = cuda_type_t; - constexpr uint32_t N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); - auto kernel = - cu::arg_reduce_general, block_dim(), N_READS>; - if (reduce_type_ == ArgReduce::ArgMin) { - kernel = cu:: - arg_reduce_general, block_dim(), N_READS>; - } - kernel<<>>( - in.data(), - out.data(), - out.size(), - const_param(shape), - const_param(in_strides), - const_param(out_strides), - ndim, - axis_stride, - axis_size); - }); + dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) { + using T = cuda_type_t; + constexpr uint32_t N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); + auto kernel = + cu::arg_reduce_general, block_dim(), N_READS>; + if (reduce_type_ == ArgReduce::ArgMin) { + kernel = cu::arg_reduce_general, block_dim(), N_READS>; + } + encoder.add_kernel_node( + kernel, + num_blocks, + block_dim(), + in.data(), + out.data(), + out.size(), + const_param(shape), + const_param(in_strides), + const_param(out_strides), + ndim, + axis_stride, + axis_size); }); }); } diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index 8e476d30f..d9b9fd8af 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -139,90 +139,92 @@ void binary_op_gpu_inplace( encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(a.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - using CTYPE_IN = MLX_GET_TYPE(in_type_tag); - using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); - if constexpr (cu::supports_binary_op()) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - auto bopt = get_binary_op_type(a, b); - if (bopt == BinaryOpType::General) { - dispatch_bool( - a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || - out.data_size() > INT32_MAX, - [&](auto large) { - using IdxT = std::conditional_t; - Shape shape; - std::vector strides; - std::tie(shape, strides) = - collapse_contiguous_dims(a, b, out); - auto& a_strides = strides[0]; - auto& b_strides = strides[1]; - int ndim = shape.size(); - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto kernel = cu::binary_g_nd< - Op, - InType, - OutType, - IdxT, - dims_constant()>; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large()); - kernel<<>>( - a.data(), - b.data(), - out.data(), - out.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides)); - }); - } else { - auto kernel = cu::binary_g; + dispatch_all_types(a.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); + if constexpr (cu::supports_binary_op()) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + auto bopt = get_binary_op_type(a, b); + if (bopt == BinaryOpType::General) { + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + out.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = collapse_contiguous_dims(a, b, out); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu:: + binary_g_nd; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, a.data(), b.data(), out.data(), out.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - ndim); - } - }); - } else { - dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { - using IdxT = std::conditional_t; - auto kernel = cu::binary_ss; - if (bopt == BinaryOpType::ScalarVector) { - kernel = cu::binary_sv; - } else if (bopt == BinaryOpType::VectorScalar) { - kernel = cu::binary_vs; - } else if (bopt == BinaryOpType::VectorVector) { - kernel = cu::binary_vv; - } - auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large()); - kernel<<>>( - a.data(), - b.data(), - out.data(), - out.data_size()); - }); - } + const_param(shape), + const_param(a_strides), + const_param(b_strides)); + }); + } else { + auto kernel = cu::binary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out.data(), + out.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + ndim); + } + }); } else { - throw std::runtime_error(fmt::format( - "Can not do binary op {} on inputs of {} with result of {}.", - op, - dtype_to_string(a.dtype()), - dtype_to_string(out.dtype()))); + dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; + auto kernel = cu::binary_ss; + if (bopt == BinaryOpType::ScalarVector) { + kernel = cu::binary_sv; + } else if (bopt == BinaryOpType::VectorScalar) { + kernel = cu::binary_vs; + } else if (bopt == BinaryOpType::VectorVector) { + kernel = cu::binary_vv; + } + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out.data(), + out.data_size()); + }); } - }); + } else { + throw std::runtime_error(fmt::format( + "Can not do binary op {} on inputs of {} with result of {}.", + op, + dtype_to_string(a.dtype()), + dtype_to_string(out.dtype()))); + } }); }); } diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index 0a68e5f1d..9582b0378 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -137,98 +137,101 @@ void binary_op_gpu_inplace( encoder.set_input_array(b); encoder.set_output_array(out_a); encoder.set_output_array(out_b); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(a.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) { - using CTYPE_IN = MLX_GET_TYPE(in_type_tag); - using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); - if constexpr (cu::supports_binary_op()) { - using InType = cuda_type_t; - using OutType = cuda_type_t; + dispatch_all_types(a.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); + if constexpr (cu::supports_binary_op()) { + using InType = cuda_type_t; + using OutType = cuda_type_t; - auto bopt = get_binary_op_type(a, b); - if (bopt == BinaryOpType::General) { - dispatch_bool( - a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || - out_a.data_size() > INT32_MAX, - [&](auto large) { - using IdxT = std::conditional_t; - Shape shape; - std::vector strides; - std::tie(shape, strides) = - collapse_contiguous_dims(a, b, out_a); - auto& a_strides = strides[0]; - auto& b_strides = strides[1]; - int ndim = shape.size(); - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto kernel = cu::binary_g_nd< - Op, - InType, - OutType, - IdxT, - dims_constant()>; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out_a, large()); - kernel<<>>( - a.data(), - b.data(), - out_a.data(), - out_b.data(), - out_a.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides)); - }); - } else { - auto kernel = cu::binary_g; + auto bopt = get_binary_op_type(a, b); + if (bopt == BinaryOpType::General) { + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + out_a.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = + collapse_contiguous_dims(a, b, out_a); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu:: + binary_g_nd; auto [num_blocks, block_dims] = get_launch_args(kernel, out_a, large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, a.data(), b.data(), out_a.data(), out_b.data(), out_a.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - ndim); - } - }); - } else { - dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) { - using IdxT = std::conditional_t; - auto kernel = cu::binary_ss; - if (bopt == BinaryOpType::ScalarVector) { - kernel = cu::binary_sv; - } else if (bopt == BinaryOpType::VectorScalar) { - kernel = cu::binary_vs; - } else if (bopt == BinaryOpType::VectorVector) { - kernel = cu::binary_vv; - } - auto [num_blocks, block_dims] = get_launch_args( - kernel, - out_a.data_size(), - out_a.shape(), - out_a.strides(), - large()); - kernel<<>>( - a.data(), - b.data(), - out_a.data(), - out_b.data(), - out_a.data_size()); - }); - } + const_param(shape), + const_param(a_strides), + const_param(b_strides)); + }); + } else { + auto kernel = cu::binary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out_a, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out_a.data(), + out_b.data(), + out_a.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + ndim); + } + }); } else { - throw std::runtime_error(fmt::format( - "Can not do binary op {} on inputs of {} with result of {}.", - op, - dtype_to_string(a.dtype()), - dtype_to_string(out_a.dtype()))); + dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; + auto kernel = cu::binary_ss; + if (bopt == BinaryOpType::ScalarVector) { + kernel = cu::binary_sv; + } else if (bopt == BinaryOpType::VectorScalar) { + kernel = cu::binary_vs; + } else if (bopt == BinaryOpType::VectorVector) { + kernel = cu::binary_vv; + } + auto [num_blocks, block_dims] = get_launch_args( + kernel, + out_a.data_size(), + out_a.shape(), + out_a.strides(), + large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out_a.data(), + out_b.data(), + out_a.data_size()); + }); } - }); + } else { + throw std::runtime_error(fmt::format( + "Can not do binary op {} on inputs of {} with result of {}.", + op, + dtype_to_string(a.dtype()), + dtype_to_string(out_a.dtype()))); + } }); }); } diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index 1aa7ecb92..21257e5dd 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -3,6 +3,7 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/jit_module.h" +#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/graph_utils.h" #include "mlx/primitives.h" @@ -178,6 +179,7 @@ void Compiled::eval_gpu( // Whether to use large index. bool large = compiled_use_large_index(inputs, outputs, contiguous); + cu::KernelArgs args; // Put inputs. int strides_index = 1; for (size_t i = 0; i < inputs.size(); ++i) { @@ -185,26 +187,26 @@ void Compiled::eval_gpu( continue; } const auto& x = inputs[i]; - mod.append_arg(x); + args.append(x); if (!contiguous && !is_scalar(x)) { - mod.append_arg(strides_vec[strides_index++]); + args.append_ptr(strides_vec[strides_index++].data()); } } // Put outputs. compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous); for (auto& x : outputs) { - mod.append_arg(x); + args.append(x); } // Put shape and size. if (!contiguous) { - mod.append_arg(shape); + args.append_ptr(shape.data()); } if (large) { - mod.append_arg(outputs[0].data_size()); + args.append(outputs[0].data_size()); } else { - mod.append_arg(outputs[0].data_size()); + args.append(outputs[0].data_size()); } // Launch kernel. @@ -222,9 +224,10 @@ void Compiled::eval_gpu( for (const auto& out : outputs) { encoder.set_output_array(out); } - encoder.launch_kernel([&](cudaStream_t stream) { - mod.launch_kernel(stream, kernel_name, outputs[0], large); - }); + + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); } } // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index 15858ded0..408350129 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -35,24 +35,25 @@ void copy_contiguous( array& out, int64_t in_offset, int64_t out_offset) { - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - using IdxT = std::conditional_t; - auto kernel = cu::copy_s; - if (ctype == CopyType::Vector) { - kernel = cu::copy_v; - } - auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large()); - kernel<<>>( - in.data() + in_offset, - out.data() + out_offset, - out.data_size()); - }); + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + auto kernel = cu::copy_s; + if (ctype == CopyType::Vector) { + kernel = cu::copy_v; + } + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in.data() + in_offset, + out.data() + out_offset, + out.data_size()); }); }); }); diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index b2703e4bf..5c7f9f954 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -55,50 +55,54 @@ void copy_general( const Shape& shape, const Strides& strides_in, const Strides& strides_out) { - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - dispatch_bool( - in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, - [&](auto large) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - using IdxT = std::conditional_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - int ndim = shape.size(); - size_t data_size = 1; - for (auto& s : shape) - data_size *= s; - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto ndim_constant) { - auto kernel = - cu::copy_gg_nd; - auto [num_blocks, block_dims] = get_launch_args( - kernel, data_size, shape, out.strides(), large()); - kernel<<>>( - in_ptr, - out_ptr, - data_size, - const_param(shape), - const_param(strides_in), - const_param(strides_out)); - }); - } else { // ndim >= 4 - auto kernel = cu::copy_gg; + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + size_t data_size = 1; + for (auto& s : shape) + data_size *= s; + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto ndim_constant) { + auto kernel = + cu::copy_gg_nd; auto [num_blocks, block_dims] = get_launch_args( kernel, data_size, shape, out.strides(), large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, in_ptr, out_ptr, data_size, - const_param(shape), - const_param(strides_in), - const_param(strides_out), - ndim); - } - }); - }); + const_param(shape), + const_param(strides_in), + const_param(strides_out)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg; + auto [num_blocks, block_dims] = get_launch_args( + kernel, data_size, shape, out.strides(), large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in_ptr, + out_ptr, + data_size, + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim); + } + }); }); }); } diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu index 68ad005d2..1b643111f 100644 --- a/mlx/backend/cuda/copy/copy_general_dynamic.cu +++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu @@ -61,54 +61,55 @@ void copy_general_dynamic( const Strides& strides_out, const array& dynamic_offset_in, const array& dynamic_offset_out) { - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - dispatch_bool( - in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, - [&](auto large) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - using IdxT = std::conditional_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - int ndim = shape.size(); - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto kernel = cu::copy_gg_dynamic_nd< - InType, - OutType, - IdxT, - dims_constant()>; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large()); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in), - const_param(strides_out), - dynamic_offset_in.data(), - dynamic_offset_out.data()); - }); - } else { // ndim >= 4 - auto kernel = cu::copy_gg_dynamic; + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu:: + copy_gg_dynamic_nd; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, in_ptr, out_ptr, out.size(), - const_param(shape), - const_param(strides_in), - const_param(strides_out), - ndim, + const_param(shape), + const_param(strides_in), + const_param(strides_out), dynamic_offset_in.data(), dynamic_offset_out.data()); - } - }); - }); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg_dynamic; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim, + dynamic_offset_in.data(), + dynamic_offset_out.data()); + } + }); }); }); } diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu index d83ba0854..1ac7712e6 100644 --- a/mlx/backend/cuda/copy/copy_general_input.cu +++ b/mlx/backend/cuda/copy/copy_general_input.cu @@ -50,45 +50,49 @@ void copy_general_input( int64_t offset_out, const Shape& shape, const Strides& strides_in) { - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - dispatch_bool( - in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, - [&](auto large) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - using IdxT = std::conditional_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - int ndim = shape.size(); - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto kernel = - cu::copy_g_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large()); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in)); - }); - } else { // ndim >= 4 - auto kernel = cu::copy_g; + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = + cu::copy_g_nd; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, in_ptr, out_ptr, out.size(), - const_param(shape), - const_param(strides_in), - ndim); - } - }); - }); + const_param(shape), + const_param(strides_in)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + ndim); + } + }); }); }); } diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index ba31c0e45..fff752fe5 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -2,38 +2,23 @@ #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/worker.h" -#include "mlx/backend/metal/metal.h" +#include "mlx/utils.h" #include #include #include +#include namespace mlx::core { +// Can be tuned with MLX_MAX_OPS_PER_BUFFER +// This should be less than 255 +constexpr int default_max_nodes_per_graph = 20; + +constexpr int max_graph_cache_size = 100; + namespace cu { -DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {} - -void DeviceStream::synchronize() { - cudaStreamSynchronize(stream_); -} - -cudaStream_t DeviceStream::schedule_cuda_stream() { - // TODO: Return a stream that maximizes parallelism. - return stream_; -} - -cudaStream_t DeviceStream::last_cuda_stream() { - return stream_; -} - -CommandEncoder& DeviceStream::get_encoder() { - if (!encoder_) { - encoder_ = std::make_unique(*this); - } - return *encoder_; -} - Device::Device(int device) : device_(device) { CHECK_CUDA_ERROR(cudaDeviceGetAttribute( &compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_)); @@ -67,49 +52,253 @@ void Device::make_current() { } } -DeviceStream& Device::get_stream(Stream s) { - auto it = streams_.find(s.index); - if (it == streams_.end()) { - it = streams_.try_emplace(s.index, *this).first; +CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) { + CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0)); + CHECK_CUDA_ERROR(cudaStreamBeginCaptureToGraph( + enc.stream(), graph, NULL, NULL, 0, cudaStreamCaptureModeGlobal)); +} + +CommandEncoder::CaptureContext::~CaptureContext() { + CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph)); + size_t num_nodes; + CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes)); + if (num_nodes == 1) { + cudaGraphNode_t captured_node; + CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes)); + CUDA_KERNEL_NODE_PARAMS params; + CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, ¶ms)); + cudaGraphNode_t node; + CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, ¶ms)); + enc.insert_graph_dependencies(GraphNode{node, 'K'}); + } else { + cudaGraphNode_t node; + CHECK_CUDA_ERROR( + cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph)); + enc.insert_graph_dependencies(GraphNode{node, 'G'}); + } + CHECK_CUDA_ERROR(cudaGraphDestroy(graph)); +} + +CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc) + : enc(enc) { + enc.in_concurrent_ = true; +} + +CommandEncoder::ConcurrentContext::~ConcurrentContext() { + enc.in_concurrent_ = false; + + // Use an empty graph node for synchronization + CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)}; + enc.empty_node_count_++; + CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0)); + + // Insert the concurrent -> empty node dependencies + for (auto& from : enc.concurrent_nodes_) { + enc.from_nodes_.push_back(from.node); + enc.to_nodes_.push_back(empty.node); + enc.graph_key_ += from.id; + enc.graph_key_ += from.node_type; + enc.graph_key_ += empty.id; + enc.graph_key_ += empty.node_type; + } + + // Insert the input -> concurrent node dependencies without updating output + // nodes + auto outputs = std::move(enc.active_outputs_); + enc.insert_graph_dependencies(std::move(enc.concurrent_nodes_)); + + // Update output node to be the empty node + for (auto o : outputs) { + enc.node_map_.emplace(o, empty).first->second = empty; + } +} + +void CommandEncoder::insert_graph_dependencies(GraphNode node) { + if (node.node_type == 'G') { + graph_node_count_++; + } + node.id = std::to_string(node_count_++); + if (in_concurrent_) { + concurrent_nodes_.push_back(std::move(node)); + } else { + std::vector nodes; + nodes.push_back(std::move(node)); + insert_graph_dependencies(std::move(nodes)); + } +} + +void CommandEncoder::insert_graph_dependencies(std::vector nodes) { + std::vector deps; + { + // Dependencies must be added in the same order to produce a consistent + // topology + std::unordered_set set_deps; + for (auto d : active_deps_) { + if (auto it = node_map_.find(d); it != node_map_.end()) { + auto [_, inserted] = set_deps.insert(it->second.node); + if (inserted) { + deps.push_back(it->second); + } + } + } + } + active_deps_.clear(); + + for (auto o : active_outputs_) { + for (auto& node : nodes) { + node_map_.emplace(o, node).first->second = node; + } + } + active_outputs_.clear(); + + for (auto& from : deps) { + for (auto& to : nodes) { + from_nodes_.push_back(from.node); + to_nodes_.push_back(to.node); + graph_key_ += from.id; + graph_key_ += from.node_type; + graph_key_ += to.id; + graph_key_ += to.node_type; + } + } +} + +CommandEncoder& Device::get_command_encoder(Stream s) { + auto it = encoders_.find(s.index); + if (it == encoders_.end()) { + it = encoders_.try_emplace(s.index, *this).first; } return it->second; } -CommandEncoder::CommandEncoder(DeviceStream& s) - : device_(s.device()), stream_(s) {} +CommandEncoder::CommandEncoder(Device& d) : stream_(d) { + CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0)); +} + +void clear_graphs(std::unordered_map& graphs) { + for (auto& [_, graph_exec] : graphs) { + CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec)); + } + graphs.clear(); +} + +CommandEncoder::~CommandEncoder() { + clear_graphs(graph_cache_); +} void CommandEncoder::add_completed_handler(std::function task) { worker_.add_task(std::move(task)); } -void CommandEncoder::end_encoding() { - if (!temporaries_.empty()) { - add_completed_handler([temporaries = std::move(temporaries_)]() {}); - } +void CommandEncoder::set_input_array(const array& arr) { + auto id = reinterpret_cast(arr.buffer().ptr()); + active_deps_.push_back(id); +} - // There is no kernel running, run completion handlers immediately. - if (!has_gpu_work_) { - worker_.consume_in_this_thread(); - return; - } - has_gpu_work_ = false; +void CommandEncoder::set_output_array(const array& arr) { + auto id = reinterpret_cast(arr.buffer().ptr()); + active_deps_.push_back(id); + active_outputs_.push_back(id); +} - // Put completion handlers in a batch. - worker_.end_batch(); - - // Signaling kernel completion is expensive, delay until enough batches. - // TODO: This number is arbitrarily picked, profile for a better stragety. - if (worker_.uncommited_batches() > 8) { +void CommandEncoder::maybe_commit() { + if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) { commit(); } } +void CommandEncoder::add_kernel_node( + void* func, + dim3 grid_dim, + dim3 block_dim, + void** params) { + cudaKernelNodeParams kernel_params = {0}; + kernel_params.func = func; + kernel_params.gridDim = grid_dim; + kernel_params.blockDim = block_dim; + kernel_params.kernelParams = params; + cudaGraphNode_t node; + CHECK_CUDA_ERROR( + cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params)); + insert_graph_dependencies(GraphNode{node, 'K'}); +} + +void CommandEncoder::add_kernel_node( + CUfunction func, + dim3 grid_dim, + dim3 block_dim, + void** params) { + CUDA_KERNEL_NODE_PARAMS kernel_params = {0}; + kernel_params.func = func; + kernel_params.gridDimX = grid_dim.x; + kernel_params.gridDimY = grid_dim.y; + kernel_params.gridDimZ = grid_dim.z; + kernel_params.blockDimX = block_dim.x; + kernel_params.blockDimY = block_dim.y; + kernel_params.blockDimZ = block_dim.z; + kernel_params.kernelParams = params; + CUgraphNode node; + CHECK_CUDA_ERROR( + cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params)); + insert_graph_dependencies(GraphNode{node, 'K'}); +} + void CommandEncoder::commit() { - worker_.commit(stream_.last_cuda_stream()); + if (!temporaries_.empty()) { + add_completed_handler([temporaries = std::move(temporaries_)]() {}); + } + if (node_count_ > 0) { + if (!from_nodes_.empty()) { + CHECK_CUDA_ERROR(cudaGraphAddDependencies( + graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size())); + } + // TODO smarter cache policy + if (graph_cache_.size() > max_graph_cache_size) { + clear_graphs(graph_cache_); + } + + graph_key_ += "."; + graph_key_ += std::to_string(node_count_); + graph_key_ += "."; + graph_key_ += std::to_string(graph_node_count_); + graph_key_ += "."; + graph_key_ += std::to_string(empty_node_count_); + auto [it, _] = graph_cache_.emplace(graph_key_, nullptr); + auto& graph_exec = it->second; + + if (graph_exec != NULL) { + cudaGraphExecUpdateResultInfo update_result; + cudaGraphExecUpdate(graph_exec, graph_, &update_result); + if (update_result.result != cudaGraphExecUpdateSuccess) { + cudaGetLastError(); + CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec)); + graph_exec = NULL; + } + } + if (graph_exec == NULL) { + CHECK_CUDA_ERROR( + cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0)); + } + CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_)); + + // Reset state + node_count_ = 0; + graph_node_count_ = 0; + from_nodes_.clear(); + to_nodes_.clear(); + graph_key_.clear(); + node_map_.clear(); + CHECK_CUDA_ERROR(cudaGraphDestroy(graph_)); + CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0)); + } + + // Put completion handlers in a batch. + worker_.end_batch(); + worker_.commit(stream_); } void CommandEncoder::synchronize() { - stream().synchronize(); + cudaStreamSynchronize(stream_); auto p = std::make_shared>(); std::future f = p->get_future(); add_completed_handler([p = std::move(p)]() { p->set_value(); }); @@ -127,12 +316,8 @@ Device& device(mlx::core::Device device) { return it->second; } -DeviceStream& get_stream(Stream s) { - return device(s.device).get_stream(s); -} - CommandEncoder& get_command_encoder(Stream s) { - return get_stream(s).get_encoder(); + return device(s.device).get_command_encoder(s); } } // namespace cu diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index 744f77f62..4ebdae55c 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -7,41 +7,108 @@ #include "mlx/stream.h" #include +#include #include #include namespace mlx::core::cu { -class Device; -class CommandEncoder; - -class DeviceStream { +class CommandEncoder { public: - explicit DeviceStream(Device& device); + struct CaptureContext { + CaptureContext(CommandEncoder& enc); + ~CaptureContext(); + cudaGraph_t graph; + CommandEncoder& enc; + }; + struct ConcurrentContext { + ConcurrentContext(CommandEncoder& enc); + ~ConcurrentContext(); + CommandEncoder& enc; + }; - DeviceStream(const DeviceStream&) = delete; - DeviceStream& operator=(const DeviceStream&) = delete; + explicit CommandEncoder(Device& d); + ~CommandEncoder(); - // Wait until kernels in the stream complete. - void synchronize(); + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; - // Return a cuda stream for launching kernels. - cudaStream_t schedule_cuda_stream(); - - // Return the last cuda stream used. - cudaStream_t last_cuda_stream(); - - CommandEncoder& get_encoder(); - - Device& device() { - return device_; + CaptureContext capture_context() { + return CaptureContext{*this}; + } + ConcurrentContext concurrent_context() { + return ConcurrentContext{*this}; } + void set_input_array(const array& arr); + void set_output_array(const array& arr); + + template + void + add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) { + constexpr size_t num = sizeof...(Params); + void* ptrs[num]; + size_t i = 0; + ([&](auto&& p) { ptrs[i++] = static_cast(&p); }( + std::forward(params)), + ...); + add_kernel_node((void*)func, grid_dim, block_dim, ptrs); + } + + void add_kernel_node( + CUfunction func, + dim3 grid_dim, + dim3 block_dim, + void** params); + + void + add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params); + + void add_temporary(const array& arr) { + temporaries_.push_back(arr.data_shared_ptr()); + } + + void add_completed_handler(std::function task); + void maybe_commit(); + void commit(); + + CudaStream& stream() { + return stream_; + } + + // Wait until kernels and completion handlers are finished + void synchronize(); + private: - Device& device_; + struct GraphNode { + cudaGraphNode_t node; + // K = kernel + // E = empty + // G = subgraph + char node_type; + std::string id; + }; + + void insert_graph_dependencies(GraphNode node); + void insert_graph_dependencies(std::vector nodes); + CudaStream stream_; - std::unique_ptr encoder_; + cudaGraph_t graph_; + Worker worker_; + char node_count_{0}; + char graph_node_count_{0}; + char empty_node_count_{0}; + bool in_concurrent_{false}; + std::vector from_nodes_; + std::vector to_nodes_; + std::string graph_key_; + std::vector concurrent_nodes_; + std::vector> temporaries_; + std::unordered_map graph_cache_; + std::vector active_deps_; + std::vector active_outputs_; + std::unordered_map node_map_; }; class Device { @@ -55,7 +122,7 @@ class Device { // Make this device the current cuda device, required by some cuda calls. void make_current(); - DeviceStream& get_stream(Stream s); + CommandEncoder& get_command_encoder(Stream s); int cuda_device() const { return device_; @@ -75,67 +142,10 @@ class Device { int compute_capability_major_; int compute_capability_minor_; cublasLtHandle_t lt_; - std::unordered_map streams_; -}; - -class CommandEncoder { - public: - explicit CommandEncoder(DeviceStream& stream); - - CommandEncoder(const CommandEncoder&) = delete; - CommandEncoder& operator=(const CommandEncoder&) = delete; - - void set_input_array(const array& arr) {} - void set_output_array(const array& arr) {} - - void add_temporary(const array& arr) { - temporaries_.push_back(arr.data_shared_ptr()); - } - - void add_completed_handler(std::function task); - void end_encoding(); - void commit(); - - // Schedule a cuda stream for |fun| to launch kernels, and check error - // afterwards. - template - void launch_kernel(F&& fun) { - launch_kernel(stream_.schedule_cuda_stream(), std::forward(fun)); - } - - template - void launch_kernel(cudaStream_t stream, F&& fun) { - device_.make_current(); - fun(stream); - check_cuda_error("kernel launch", cudaGetLastError()); - has_gpu_work_ = true; - } - - Device& device() { - return device_; - } - - DeviceStream& stream() { - return stream_; - } - - bool has_gpu_work() const { - return has_gpu_work_; - } - - // Wait until kernels and completion handlers are finished - void synchronize(); - - private: - Device& device_; - DeviceStream& stream_; - Worker worker_; - bool has_gpu_work_{false}; - std::vector> temporaries_; + std::unordered_map encoders_; }; Device& device(mlx::core::Device device); -DeviceStream& get_stream(Stream s); CommandEncoder& get_command_encoder(Stream s); // Return an execution policy that does not sync for result. diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp index 21b019cd8..40beb12d2 100644 --- a/mlx/backend/cuda/eval.cpp +++ b/mlx/backend/cuda/eval.cpp @@ -37,22 +37,20 @@ void eval(array& arr) { } auto& encoder = cu::get_command_encoder(arr.primitive().stream()); - if (encoder.has_gpu_work()) { - // Keep used buffers alive until kernel finishes running. - std::unordered_set> buffers; - for (auto& in : arr.inputs()) { - buffers.insert(in.data_shared_ptr()); - } - for (auto& s : arr.siblings()) { - buffers.insert(s.data_shared_ptr()); - } - // Remove the output if it was donated to by an input. - if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { - buffers.erase(it); - } - encoder.add_completed_handler([buffers = std::move(buffers)]() {}); + // Keep used buffers alive until kernel finishes running. + std::unordered_set> buffers; + for (auto& in : arr.inputs()) { + buffers.insert(in.data_shared_ptr()); } - encoder.end_encoding(); + for (auto& s : arr.siblings()) { + buffers.insert(s.data_shared_ptr()); + } + // Remove the output if it was donated to by an input. + if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { + buffers.erase(it); + } + encoder.add_completed_handler([buffers = std::move(buffers)]() {}); + encoder.maybe_commit(); } void finalize(Stream s) { diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu index 9fc5c641b..afa032a83 100644 --- a/mlx/backend/cuda/event.cu +++ b/mlx/backend/cuda/event.cu @@ -61,7 +61,9 @@ void CudaEvent::wait(Stream s) { if (s.device == mlx::core::Device::cpu) { scheduler::enqueue(s, [*this]() mutable { wait(); }); } else { - wait(cu::get_stream(s).last_cuda_stream()); + auto& enc = cu::get_command_encoder(s); + enc.commit(); + wait(enc.stream()); } } @@ -74,7 +76,9 @@ void CudaEvent::record(Stream s) { if (s.device == mlx::core::Device::cpu) { throw std::runtime_error("CudaEvent can not wait on cpu stream."); } else { - record(cu::get_stream(s).last_cuda_stream()); + auto& enc = cu::get_command_encoder(s); + enc.commit(); + record(enc.stream()); } } @@ -136,11 +140,9 @@ void SharedEvent::wait(Stream s, uint64_t value) { scheduler::enqueue(s, [*this, value]() mutable { wait(value); }); } else { auto& encoder = get_command_encoder(s); - encoder.launch_kernel( - encoder.stream().last_cuda_stream(), - [this, value](cudaStream_t stream) { wait(stream, value); }); + encoder.commit(); + wait(encoder.stream(), value); encoder.add_completed_handler([ac = ac_]() {}); - encoder.end_encoding(); } } @@ -162,11 +164,9 @@ void SharedEvent::signal(Stream s, uint64_t value) { scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); }); } else { auto& encoder = get_command_encoder(s); - encoder.launch_kernel( - encoder.stream().last_cuda_stream(), - [this, value](cudaStream_t stream) { signal(stream, value); }); + encoder.commit(); + signal(encoder.stream(), value); encoder.add_completed_handler([ac = ac_]() {}); - encoder.end_encoding(); } } diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index 65a175fbd..4b03a604e 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -3,13 +3,16 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/jit_module.h" +#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" #include "cuda_jit_sources.h" +#include #include +#include #include #include @@ -22,7 +25,7 @@ namespace { constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"}; void append_indices_arg( - cu::JitModule& mod, + cu::KernelArgs& args, const std::vector& inputs, int nidx, int idx_ndim) { @@ -30,7 +33,7 @@ void append_indices_arg( for (int i = 0; i < nidx; ++i) { indices[i] = inputs[i + 1].data(); } - mod.append_arg(std::move(indices)); + args.append(std::move(indices)); std::vector indices_shape(nidx * idx_ndim); for (int i = 0; i < nidx; ++i) { std::copy_n( @@ -38,7 +41,7 @@ void append_indices_arg( idx_ndim, indices_shape.data() + i * idx_ndim); } - mod.append_arg(std::move(indices_shape)); + args.append(std::move(indices_shape)); std::vector indices_strides(nidx * idx_ndim); for (int i = 0; i < nidx; ++i) { std::copy_n( @@ -46,7 +49,7 @@ void append_indices_arg( idx_ndim, indices_strides.data() + i * idx_ndim); } - mod.append_arg(std::move(indices_strides)); + args.append(std::move(indices_strides)); } } // namespace @@ -94,20 +97,21 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { return std::make_pair(jit_source_gather, std::move(kernel_names)); }); - mod.append_arg(src); - mod.append_arg(out); + cu::KernelArgs args; + args.append(src); + args.append(out); if (large) { - mod.append_arg(out.size()); + args.append(out.size()); } else { - mod.append_arg(out.size()); + args.append(out.size()); } - mod.append_ndim_arg(src.shape()); - mod.append_ndim_arg(src.strides()); - mod.append_arg(src.ndim()); - mod.append_ndim_arg(slice_sizes_); - mod.append_arg(slice_size); - mod.append_arg(axes_); - append_indices_arg(mod, inputs, nidx, idx_ndim); + args.append_ndim(src.shape()); + args.append_ndim(src.strides()); + args.append(src.ndim()); + args.append_ndim(slice_sizes_); + args.append(slice_size); + args.append(axes_); + append_indices_arg(args, inputs, nidx, idx_ndim); std::string kernel_name = fmt::format( "mlx::core::cu::gather<{}, {}, {}, {}, {}>", @@ -122,9 +126,10 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - mod.launch_kernel(stream, kernel_name, out, large); - }); + + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); } void Scatter::eval_gpu(const std::vector& inputs, array& out) { @@ -187,26 +192,27 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { return std::make_pair(jit_source_scatter, std::move(kernel_names)); }); - mod.append_arg(upd); - mod.append_arg(out); + cu::KernelArgs args; + args.append(upd); + args.append(out); if (large) { - mod.append_arg(upd.size()); + args.append(upd.size()); } else { - mod.append_arg(upd.size()); + args.append(upd.size()); } - mod.append_ndim_arg(upd.shape()); - mod.append_ndim_arg(upd.strides()); - mod.append_arg(upd.ndim()); + args.append_ndim(upd.shape()); + args.append_ndim(upd.strides()); + args.append(upd.ndim()); if (large) { - mod.append_arg(upd_post_idx_size); + args.append(upd_post_idx_size); } else { - mod.append_arg(upd_post_idx_size); + args.append(upd_post_idx_size); } - mod.append_ndim_arg(out.shape()); - mod.append_ndim_arg(out.strides()); - mod.append_arg(out.ndim()); - mod.append_arg(axes_); - append_indices_arg(mod, inputs, nidx, idx_ndim); + args.append_ndim(out.shape()); + args.append_ndim(out.strides()); + args.append(out.ndim()); + args.append(axes_); + append_indices_arg(args, inputs, nidx, idx_ndim); std::string kernel_name = fmt::format( "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>", @@ -222,9 +228,9 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - mod.launch_kernel(stream, kernel_name, upd, large); - }); + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); } void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { @@ -275,25 +281,26 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { } size_t idx_size_axis = idx.shape(axis_); - mod.append_arg(src); - mod.append_arg(idx); - mod.append_arg(out); + cu::KernelArgs args; + args.append(src); + args.append(idx); + args.append(out); if (large) { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + args.append(idx_size_pre); + args.append(idx_size_axis); + args.append(idx_size_post); } else { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + args.append(idx_size_pre); + args.append(idx_size_axis); + args.append(idx_size_post); } - mod.append_arg(remove_index(idx.shape(), axis_)); - mod.append_arg(remove_index(src.strides(), axis_)); - mod.append_arg(remove_index(idx.strides(), axis_)); - mod.append_arg(axis_); - mod.append_arg(src.shape(axis_)); - mod.append_arg(src.strides(axis_)); - mod.append_arg(idx.strides(axis_)); + args.append(remove_index(idx.shape(), axis_)); + args.append(remove_index(src.strides(), axis_)); + args.append(remove_index(idx.strides(), axis_)); + args.append(axis_); + args.append(src.shape(axis_)); + args.append(src.strides(axis_)); + args.append(idx.strides(axis_)); std::string kernel_name = fmt::format( "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>", @@ -309,9 +316,9 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - mod.launch_kernel(stream, kernel_name, idx, large); - }); + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); } void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { @@ -377,25 +384,26 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { } size_t idx_size_axis = idx.shape(axis_); - mod.append_arg(upd); - mod.append_arg(idx); - mod.append_arg(out); + cu::KernelArgs args; + args.append(upd); + args.append(idx); + args.append(out); if (large) { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + args.append(idx_size_pre); + args.append(idx_size_axis); + args.append(idx_size_post); } else { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + args.append(idx_size_pre); + args.append(idx_size_axis); + args.append(idx_size_post); } - mod.append_arg(remove_index(idx.shape(), axis_)); - mod.append_arg(remove_index(upd.strides(), axis_)); - mod.append_arg(remove_index(idx.strides(), axis_)); - mod.append_arg(axis_); - mod.append_arg(out.shape(axis_)); - mod.append_arg(upd.strides(axis_)); - mod.append_arg(idx.strides(axis_)); + args.append(remove_index(idx.shape(), axis_)); + args.append(remove_index(upd.strides(), axis_)); + args.append(remove_index(idx.strides(), axis_)); + args.append(axis_); + args.append(out.shape(axis_)); + args.append(upd.strides(axis_)); + args.append(idx.strides(axis_)); std::string kernel_name = fmt::format( "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>", @@ -412,9 +420,9 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - mod.launch_kernel(stream, kernel_name, idx, large); - }); + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); } } // namespace mlx::core diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index af8f7dc75..5bc56b25e 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -26,16 +26,6 @@ void check_nvrtc_error(const char* name, nvrtcResult err) { } } -#define CHECK_CU_ERROR(cmd) check_cu_error(#cmd, (cmd)) - -void check_cu_error(const char* name, CUresult err) { - if (err != CUDA_SUCCESS) { - const char* err_str = "Unknown error"; - cuGetErrorString(err, &err_str); - throw std::runtime_error(fmt::format("{} failed: {}", name, err_str)); - } -} - // Return the location of the CUDA toolkit. const std::string& cuda_home() { static std::string home = []() -> std::string { @@ -280,60 +270,13 @@ JitModule::JitModule( // Load kernels. for (const auto& [name, mangled] : ptx_kernels) { CUfunction kernel; - CHECK_CU_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str())); + CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str())); kernels_[name] = kernel; } } JitModule::~JitModule() { - CHECK_CU_ERROR(cuModuleUnload(module_)); -} - -void JitModule::launch_kernel( - CUstream stream, - const std::string& kernel_name, - const array& arr, - bool large, - int work_per_thread) { - CUfunction kernel = get_kernel(kernel_name); - size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread); - int _, block_dim; - CHECK_CU_ERROR( - cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0)); - if (block_dim > nthreads) { - block_dim = nthreads; - } - Dims num_blocks{1, 1, 1}; - if (large) { - num_blocks = - get_2d_grid_dims_common(arr.shape(), arr.strides(), work_per_thread); - std::get<0>(num_blocks) = - (std::get<0>(num_blocks) + block_dim - 1) / block_dim; - } else { - std::get<0>(num_blocks) = (nthreads + block_dim - 1) / block_dim; - } - launch_kernel(stream, kernel, num_blocks, Dims{block_dim, 1, 1}); -} - -void JitModule::launch_kernel( - CUstream stream, - CUfunction kernel, - Dims num_blocks, - Dims block_dims) { - CHECK_CU_ERROR(cuLaunchKernel( - kernel, - std::get<0>(num_blocks), - std::get<1>(num_blocks), - std::get<2>(num_blocks), - std::get<0>(block_dims), - std::get<1>(block_dims), - std::get<2>(block_dims), - 0, - stream, - args_.data(), - nullptr)); - args_.clear(); - storage_.clear(); + CHECK_CUDA_ERROR(cuModuleUnload(module_)); } CUfunction JitModule::get_kernel(const std::string& kernel_name) { @@ -345,10 +288,6 @@ CUfunction JitModule::get_kernel(const std::string& kernel_name) { return it->second; } -void JitModule::append_ptr_arg(const void* v) { - args_.push_back(const_cast(v)); -} - JitModule& get_jit_module( const mlx::core::Device& device, const std::string& name, diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h index bbfaa45b0..57da7c87e 100644 --- a/mlx/backend/cuda/jit_module.h +++ b/mlx/backend/cuda/jit_module.h @@ -4,6 +4,7 @@ #include "mlx/array.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/config.h" #include @@ -23,72 +24,48 @@ using KernelBuilderResult = std::pair< /* kernel names */ std::vector>; using KernelBuilder = std::function; -class JitModule { - public: - JitModule( - Device& device, - const std::string& module_name, - const KernelBuilder& builder); - ~JitModule(); +struct KernelArgs { + void** args() { + return args_.data(); + } - JitModule(const JitModule&) = delete; - JitModule& operator=(const JitModule&) = delete; - - void append_arg(const array& a) { - append_arg(reinterpret_cast(a.data())); + void append(const array& a) { + append(reinterpret_cast(a.data())); } template - void append_arg(T val) { + void append(T val) { storage_.emplace_back(val); - append_ptr_arg(&storage_.back()); + append_ptr(&storage_.back()); } template - void append_arg(std::vector vec) { + void append(std::vector vec) { if (vec.empty()) { // The nullptr can not be used as arg, pass something not null. - append_arg(std::monostate{}); + append(std::monostate{}); } else { - append_ptr_arg(vec.data()); + append_ptr(vec.data()); storage_.emplace_back(std::move(vec)); } } // Make sure the arg is copied to an array with size of NDIM. template - void append_ndim_arg(const std::vector& vec) { + void append_ndim(std::vector vec) { if (vec.size() > NDIM) { throw std::runtime_error( fmt::format("ndim can not be larger than {}.", NDIM)); } - std::vector copied(NDIM); - std::copy(vec.begin(), vec.end(), copied.data()); - append_arg(std::move(copied)); + vec.resize(NDIM); + append(std::move(vec)); } - // Launch kernel with |kernel_name| that each thread works on - // |work_per_thread| elements of |arr|. - void launch_kernel( - CUstream stream, - const std::string& kernel_name, - const array& arr, - bool large, - int work_per_thread = 1); - - void launch_kernel( - CUstream stream, - CUfunction kernel, - Dims num_blocks, - Dims block_dims); - - CUfunction get_kernel(const std::string& kernel_name); + void append_ptr(const void* v) { + args_.push_back(const_cast(v)); + } private: - void append_ptr_arg(const void* v); - - CUmodule module_{nullptr}; - std::unordered_map kernels_; std::vector args_; // The cuLaunchKernel API requires passing pointers to arguments so store @@ -105,6 +82,23 @@ class JitModule { std::deque storage_; }; +class JitModule { + public: + JitModule( + Device& device, + const std::string& module_name, + const KernelBuilder& builder); + ~JitModule(); + + JitModule(const JitModule&) = delete; + JitModule& operator=(const JitModule&) = delete; + CUfunction get_kernel(const std::string& kernel_name); + + private: + CUmodule module_{nullptr}; + std::unordered_map kernels_; +}; + JitModule& get_jit_module( const mlx::core::Device& device, const std::string& name, diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index b0058b618..eeaf916c1 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -12,6 +12,7 @@ #include "mlx/backend/cuda/device/utils.cuh" #include +#include #include #include #include @@ -120,7 +121,13 @@ std::pair get_grid_and_block(int dim0, int dim1, int dim2); template inline uint max_occupancy_block_dim(T kernel) { int _, block_dim; - CHECK_CUDA_ERROR(cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel)); + if constexpr (std::is_same_v) { + CHECK_CUDA_ERROR( + cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0)); + } else { + CHECK_CUDA_ERROR( + cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel)); + } return block_dim; } diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu index 9a9fbcb37..5fbf949d7 100644 --- a/mlx/backend/cuda/layer_norm.cu +++ b/mlx/backend/cuda/layer_norm.cu @@ -258,23 +258,23 @@ void LayerNorm::eval_gpu( encoder.set_input_array(w); encoder.set_input_array(b); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) { - constexpr uint32_t N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - auto kernel = cu::layer_norm; - kernel<<>>( - x.data(), - w.data(), - b.data(), - out.data(), - eps_, - axis_size, - w_stride, - b_stride); - }); + dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) { + constexpr uint32_t N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::layer_norm; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + x.data(), + w.data(), + b.data(), + out.data(), + eps_, + axis_size, + w_stride, + b_stride); }); }); } @@ -289,21 +289,25 @@ void LayerNormVJP::eval_gpu( // Ensure row contiguity. We could relax this step by checking that the array // is contiguous (no broadcasts or holes) and that the input strides are the // same as the cotangent strides but for now this is simpler. - auto check_input = [&s](const array& x) -> std::pair { + auto check_input = [&s](const array& x, bool& copied) { if (x.flags().row_contiguous) { - return {x, false}; + copied = false; + return x; } + copied = true; array x_copy(x.shape(), x.dtype(), nullptr, {}); copy_gpu(x, x_copy, CopyType::General, s); - return {x_copy, true}; + return x_copy; }; bool donate_x = inputs[0].is_donatable(); bool donate_g = inputs[3].is_donatable(); - auto [x, copied] = check_input(inputs[0]); + bool copied; + auto x = check_input(inputs[0], copied); donate_x |= copied; const array& w = inputs[1]; const array& b = inputs[2]; - auto [g, g_copied] = check_input(inputs[3]); + bool g_copied; + auto g = check_input(inputs[3], g_copied); donate_g |= g_copied; array& gx = outputs[0]; array& gw = outputs[1]; @@ -334,8 +338,10 @@ void LayerNormVJP::eval_gpu( // gradient accumulators. array gw_temp = (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; + bool g_in_gw = false; if (has_w) { if (!g_in_gx && donate_g) { + g_in_gw = true; gw_temp.copy_shared_buffer(g); } else { gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); @@ -343,41 +349,47 @@ void LayerNormVJP::eval_gpu( } } - // Finish with the gradient for b in case we had a b. - if (gb.ndim() == 1 && gb.size() == axis_size) { + // The gradient for b in case we had a b. + bool has_gb = (gb.ndim() == 1 && gb.size() == axis_size); + if (has_gb) { ReductionPlan plan( ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan); } + // Insert dependency if `g` was donated + if ((g_in_gx || g_in_gw) && has_gb) { + encoder.set_input_array(gb); + } encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_input_array(g); encoder.set_output_array(gx); encoder.set_output_array(gw_temp); - encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { - dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) { - dispatch_bool(has_w, [&](auto has_w_constant) { - constexpr int N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - auto kernel = cu::layer_norm_vjp< - DataType, - has_w_constant.value, - block_dim(), - N_READS>; - kernel<<>>( - x.data(), - w.data(), - g.data(), - gx.data(), - gw_temp.data(), - eps_, - axis_size, - w_stride); - }); - }); + dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) { + dispatch_bool(has_w, [&](auto has_w_constant) { + constexpr int N_READS = 4; + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::layer_norm_vjp< + DataType, + has_w_constant.value, + block_dim(), + N_READS>; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); }); }); diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu index 5d6bf437d..afc52826f 100644 --- a/mlx/backend/cuda/logsumexp.cu +++ b/mlx/backend/cuda/logsumexp.cu @@ -143,16 +143,18 @@ void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) { - constexpr int N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - auto kernel = cu::logsumexp; - kernel<<>>( - in.data(), out.data(), axis_size); - }); + dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) { + constexpr int N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::logsumexp; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + in.data(), + out.data(), + axis_size); }); }); } diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp index c32cecc03..e11c68b7d 100644 --- a/mlx/backend/cuda/matmul.cpp +++ b/mlx/backend/cuda/matmul.cpp @@ -42,7 +42,8 @@ class MatMul { int64_t ldb, int32_t batch_count, int64_t a_batch_stride, - int64_t b_batch_stride) { + int64_t b_batch_stride) + : handle_(device.lt_handle()) { heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; auto scale_type = dtype_to_cuda_type(dtype); @@ -147,7 +148,7 @@ class MatMul { if (heuristic_.state != CUBLAS_STATUS_SUCCESS) { int ret = 0; CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic( - encoder.device().lt_handle(), + handle_, matmul_desc_, a_desc_, b_desc_, @@ -172,25 +173,24 @@ class MatMul { workspace_ptr = workspace.data(); } - encoder.launch_kernel([&](cudaStream_t stream) { - CHECK_CUBLAS_ERROR(cublasLtMatmul( - encoder.device().lt_handle(), - matmul_desc_, - &alpha, - a, - a_desc_, - b, - b_desc_, - &beta, - c ? c : out, - c ? c_desc_ : out_desc_, - out, - out_desc_, - &heuristic_.algo, - workspace_ptr, - heuristic_.workspaceSize, - stream)); - }); + auto capture = encoder.capture_context(); + CHECK_CUBLAS_ERROR(cublasLtMatmul( + handle_, + matmul_desc_, + &alpha, + a, + a_desc_, + b, + b_desc_, + &beta, + c ? c : out, + c ? c_desc_ : out_desc_, + out, + out_desc_, + &heuristic_.algo, + workspace_ptr, + heuristic_.workspaceSize, + encoder.stream())); } private: @@ -259,6 +259,7 @@ class MatMul { return desc; } + cublasLtHandle_t handle_{nullptr}; cublasLtMatmulDesc_t matmul_desc_{nullptr}; cublasLtMatmulPreference_t pref_{nullptr}; cublasLtMatrixLayout_t a_desc_{nullptr}; @@ -273,7 +274,7 @@ class MatMul { namespace { std::tuple -check_transpose(std::vector& copies, const Stream& s, const array& arr) { +check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) { auto stx = arr.strides()[arr.ndim() - 2]; auto sty = arr.strides()[arr.ndim() - 1]; if (sty == 1 && stx == arr.shape(-1)) { @@ -283,7 +284,7 @@ check_transpose(std::vector& copies, const Stream& s, const array& arr) { } else { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); + enc.add_temporary(arr_copy); return std::make_tuple(false, arr.shape(-1), arr_copy); } } @@ -317,13 +318,8 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { // Keep a vector with copies to be cleared in the completed buffer to release // the arrays - std::vector copies; - auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre); - auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre); - - for (auto& temp : copies) { - encoder.add_temporary(temp); - } + auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); + auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions @@ -348,7 +344,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { // Invoke cublasLt cu::MatMul matmul( - encoder.device(), + cu::device(s.device), a.dtype(), a_transposed, M, @@ -373,6 +369,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); + auto concurrent = encoder.concurrent_context(); for (size_t i = 0; i < nbatch; ++i) { matmul.run( encoder, @@ -405,14 +402,9 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Keep a vector with copies to be cleared in the completed buffer to release // the arrays - std::vector copies; - auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre); - auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre); - auto [c_transposed, ldc, c] = check_transpose(copies, s, c_pre); - - for (auto& temp : copies) { - encoder.add_temporary(temp); - } + auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); + auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); + auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre); ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions @@ -440,7 +432,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Invoke cublasLt cu::MatMul matmul( - encoder.device(), + cu::device(s.device), a.dtype(), a_transposed, M, @@ -478,6 +470,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); + auto concurrent = encoder.concurrent_context(); for (size_t i = 0; i < nbatch; ++i) { matmul.run( encoder, diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 715e5a232..18fa45a33 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -24,23 +24,21 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { if (out.size() == 0) { return; } - auto& s = stream(); - auto& encoder = cu::get_command_encoder(s); + auto& encoder = cu::get_command_encoder(stream()); encoder.set_output_array(out); - encoder.launch_kernel([&, this](cudaStream_t stream) { - dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { - using CTYPE = MLX_GET_TYPE(type_tag); - using OutType = cuda_type_t; - CTYPE step = - static_cast(start_ + step_) - static_cast(start_); - thrust::transform( - cu::thrust_policy(stream), - thrust::counting_iterator(0), - thrust::counting_iterator(out.data_size()), - thrust::device_pointer_cast(out.data()), - cu::Arange{ - static_cast(start_), static_cast(step)}); - }); + auto capture = encoder.capture_context(); + dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + using OutType = cuda_type_t; + CTYPE step = + static_cast(start_ + step_) - static_cast(start_); + thrust::transform( + cu::thrust_policy(encoder.stream()), + thrust::counting_iterator(0), + thrust::counting_iterator(out.data_size()), + thrust::device_pointer_cast(out.data()), + cu::Arange{ + static_cast(start_), static_cast(step)}); }); } diff --git a/mlx/backend/cuda/random.cu b/mlx/backend/cuda/random.cu index 0cb550d56..7221af356 100644 --- a/mlx/backend/cuda/random.cu +++ b/mlx/backend/cuda/random.cu @@ -156,34 +156,39 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(keys); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dim3 grid_dims{num_keys, half_size + odd}; - int64_t total = grid_dims.x * grid_dims.y; - int32_t threads_y = 1; - while ((total / threads_y) >= (1U << 31)) { - threads_y *= 2; - } - int32_t threads_x = cuda::ceil_div(total, threads_y); - auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1); - if (keys.flags().row_contiguous) { - cu::rbitsc<<>>( - keys.data(), - out.data(), - grid_dims, - odd, - bytes_per_key); - } else { - cu::rbits<<>>( - keys.data(), - out.data(), - grid_dims, - odd, - bytes_per_key, - keys.ndim(), - const_param(keys.shape()), - const_param(keys.strides())); - } - }); + dim3 grid_dims{num_keys, half_size + odd}; + int64_t total = grid_dims.x * grid_dims.y; + int32_t threads_y = 1; + while ((total / threads_y) >= (1U << 31)) { + threads_y *= 2; + } + int32_t threads_x = cuda::ceil_div(total, threads_y); + auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1); + auto& stream = encoder.stream(); + if (keys.flags().row_contiguous) { + encoder.add_kernel_node( + cu::rbitsc, + grid, + block, + keys.data(), + out.data(), + grid_dims, + odd, + bytes_per_key); + } else { + encoder.add_kernel_node( + cu::rbits, + grid, + block, + keys.data(), + out.data(), + grid_dims, + odd, + bytes_per_key, + keys.ndim(), + const_param(keys.shape()), + const_param(keys.strides())); + } } } // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index a6ccd5ae9..3419d61cb 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -110,19 +110,20 @@ void all_reduce( intermediate.set_data(allocator::malloc(intermediate.nbytes())); encoder.add_temporary(intermediate); encoder.set_output_array(intermediate); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(dt, [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; - auto kernel = cu::all_reduce; - kernel<<>>( - static_cast(indata), - intermediate.data(), - block_step, - insize); - }); + dispatch_all_types(dt, [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + auto kernel = cu::all_reduce; + encoder.add_kernel_node( + kernel, + blocks, + threads, + static_cast(indata), + intermediate.data(), + block_step, + insize); }); }); @@ -135,16 +136,20 @@ void all_reduce( } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(dt, [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; - auto kernel = cu::all_reduce; - kernel<<>>( - static_cast(indata), out.data(), block_step, insize); - }); + dispatch_all_types(dt, [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + auto kernel = cu::all_reduce; + encoder.add_kernel_node( + kernel, + blocks, + threads, + static_cast(indata), + out.data(), + block_step, + insize); }); }); } diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index 78f6b93bc..910fa0379 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -214,26 +214,24 @@ void col_reduce_looped( encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + // Cub doesn't like const pointers for vectorized loads. (sigh) + T* indata = const_cast(in.data()); - // Cub doesn't like const pointers for vectorized loads. (sigh) - T* indata = const_cast(in.data()); - - constexpr int N_READS = 4; - constexpr int BM = 32; - constexpr int BN = 32; - dim3 grid = output_grid_for_col_reduce(out, args, BN); - int blocks = BM * BN / N_READS; - auto kernel = - cu::col_reduce_looped; - kernel<<>>(indata, out.data(), args); - }); + constexpr int N_READS = 4; + constexpr int BM = 32; + constexpr int BN = 32; + dim3 grid = output_grid_for_col_reduce(out, args, BN); + int blocks = BM * BN / N_READS; + auto kernel = + cu::col_reduce_looped; + encoder.add_kernel_node( + kernel, grid, blocks, indata, out.data(), args); }); }); }); diff --git a/mlx/backend/cuda/reduce/init_reduce.cu b/mlx/backend/cuda/reduce/init_reduce.cu index 296a4e611..649d80190 100644 --- a/mlx/backend/cuda/reduce/init_reduce.cu +++ b/mlx/backend/cuda/reduce/init_reduce.cu @@ -32,18 +32,16 @@ void init_reduce( } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; - auto kernel = cu::init_reduce; - dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); - dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1); - grid.x = (grid.x + 1023) / 1024; - kernel<<>>(out.data(), out.size()); - }); + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + auto kernel = cu::init_reduce; + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1); + grid.x = (grid.x + 1023) / 1024; + encoder.add_kernel_node(kernel, grid, block, out.data(), out.size()); }); }); } diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index deb4a2f91..e57f18668 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -245,34 +245,32 @@ void row_reduce_simple( // 2 passes. Something like 32 * out.size() and then do a warp reduce. encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; - // Cub doesn't like const pointers for vectorized loads. (sigh) - T* indata = const_cast(in.data()); + // Cub doesn't like const pointers for vectorized loads. (sigh) + T* indata = const_cast(in.data()); - // Calculate the grid and block dims - size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS; - dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); - int threads = std::min(1024UL, reductions); - threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - dim3 block(threads, 1, 1); + // Calculate the grid and block dims + size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS; + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + int threads = std::min(1024UL, reductions); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + dim3 block(threads, 1, 1); - // Pick the kernel - auto kernel = cu::row_reduce_simple; - if (grid.x >= 1024) { - grid.x = (grid.x + 1) / 2; - kernel = cu::row_reduce_simple; - } + // Pick the kernel + auto kernel = cu::row_reduce_simple; + if (grid.x >= 1024) { + grid.x = (grid.x + 1) / 2; + kernel = cu::row_reduce_simple; + } - // Launch - kernel<<>>( - indata, out.data(), out.size(), plan.shape.back()); - }); + int size = plan.shape.back(); + encoder.add_kernel_node( + kernel, grid, block, indata, out.data(), out.size(), size); }); }); } @@ -293,43 +291,39 @@ void row_reduce_looped( encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + // Cub doesn't like const pointers for vectorized loads. (sigh) + T* indata = const_cast(in.data()); - // Cub doesn't like const pointers for vectorized loads. (sigh) - T* indata = const_cast(in.data()); + // Calculate the grid and block dims + args.sort_access_pattern(in, axes); + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + size_t reductions = (args.row_size + N_READS - 1) / N_READS; + int threads = std::min(1024UL, reductions); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + dim3 block(threads, 1, 1); - // Calculate the grid and block dims - args.sort_access_pattern(in, axes); - dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); - size_t reductions = (args.row_size + N_READS - 1) / N_READS; - int threads = std::min(1024UL, reductions); - threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - dim3 block(threads, 1, 1); - - // Pick the kernel - auto kernel = cu::row_reduce_looped; - dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { - dispatch_block_dim(threads, [&](auto threads_constant) { - kernel = cu::row_reduce_looped< - T, - U, - OP, - reduce_ndim.value, - threads_constant.value, - N_READS>; - block.x = threads_constant.value; - }); + // Pick the kernel + auto kernel = cu::row_reduce_looped; + dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { + dispatch_block_dim(threads, [&](auto threads_constant) { + kernel = cu::row_reduce_looped< + T, + U, + OP, + reduce_ndim.value, + threads_constant.value, + N_READS>; + block.x = threads_constant.value; }); - - // Launch - kernel<<>>( - indata, out.data(), out.size(), args); }); + + encoder.add_kernel_node( + kernel, grid, block, indata, out.data(), out.size(), args); }); }); } diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu index fc8f4f490..5ee1d3386 100644 --- a/mlx/backend/cuda/rms_norm.cu +++ b/mlx/backend/cuda/rms_norm.cu @@ -224,21 +224,21 @@ void RMSNorm::eval_gpu( encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) { - constexpr uint32_t N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - auto kernel = cu::rms_norm; - kernel<<>>( - x.data(), - w.data(), - out.data(), - eps_, - axis_size, - w_stride); - }); + dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) { + constexpr uint32_t N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::rms_norm; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + x.data(), + w.data(), + out.data(), + eps_, + axis_size, + w_stride); }); }); } @@ -253,20 +253,24 @@ void RMSNormVJP::eval_gpu( // Ensure row contiguity. We could relax this step by checking that the array // is contiguous (no broadcasts or holes) and that the input strides are the // same as the cotangent strides but for now this is simpler. - auto check_input = [&s](const array& x) -> std::pair { + auto check_input = [&s](const array& x, bool& copied) { if (x.flags().row_contiguous) { - return {x, false}; + copied = false; + return x; } + copied = true; array x_copy(x.shape(), x.dtype(), nullptr, {}); copy_gpu(x, x_copy, CopyType::General, s); - return {x_copy, true}; + return x_copy; }; bool donate_x = inputs[0].is_donatable(); bool donate_g = inputs[2].is_donatable(); - auto [x, copied] = check_input(inputs[0]); + bool copied; + auto x = check_input(inputs[0], copied); donate_x |= copied; const array& w = inputs[1]; - auto [g, g_copied] = check_input(inputs[2]); + bool g_copied; + auto g = check_input(inputs[2], g_copied); donate_g |= g_copied; array& gx = outputs[0]; array& gw = outputs[1]; @@ -310,30 +314,31 @@ void RMSNormVJP::eval_gpu( encoder.set_input_array(g); encoder.set_output_array(gx); encoder.set_output_array(gw_temp); - encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { - dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) { - dispatch_bool(has_w, [&](auto has_w_constant) { - constexpr int N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - constexpr int N_READS = 4; - auto kernel = cu::rms_norm_vjp< - DataType, - has_w_constant.value, - block_dim(), - N_READS>; - kernel<<>>( - x.data(), - w.data(), - g.data(), - gx.data(), - gw_temp.data(), - eps_, - axis_size, - w_stride); - }); - }); + dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) { + dispatch_bool(has_w, [&](auto has_w_constant) { + constexpr int N_READS = 4; + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + constexpr int N_READS = 4; + auto kernel = cu::rms_norm_vjp< + DataType, + has_w_constant.value, + block_dim(), + N_READS>; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); }); }); diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index bb9618fc4..517cddfe0 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -308,76 +308,89 @@ void RoPE::eval_gpu( auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(donated ? out : in); encoder.set_input_array(offset); + if (with_freqs) { + encoder.set_input_array(inputs[2]); + } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { - dispatch_bool(traditional_, [&](auto traditional) { - dispatch_bool(forward_, [&](auto forward) { - using DataType = cuda_type_t; - if (single && !with_freqs) { - auto kernel = - cu::rope_single; - uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); - auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); - kernel<<>>( - (donated ? out : in).data(), - out.data(), - offset.data(), - scale_, - std::log2(base_), - mat_size, - dims); - } else if (single) { - auto kernel = cu:: - rope_single_freqs; - uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); - auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); - kernel<<>>( - (donated ? out : in).data(), - out.data(), - offset.data(), - inputs[2].data(), - scale_, - mat_size, - dims, - inputs[2].strides(0)); - } else if (with_freqs) { - auto kernel = - cu::rope_freqs; - uint3 dims = - make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); - dims.z = (dims.z + 3) / 4; - auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); - kernel<<>>( - (donated ? out : in).data(), - out.data(), - offset.data(), - inputs[2].data(), - scale_, - std::log2(base_), - strides, - out_strides, - in.size() / mat_size, - dims, - inputs[2].strides(0)); - } else { - auto kernel = cu::rope; - uint3 dims = - make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); - dims.z = (dims.z + 3) / 4; - auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); - kernel<<>>( - (donated ? out : in).data(), - out.data(), - offset.data(), - scale_, - std::log2(base_), - strides, - out_strides, - in.size() / mat_size, - dims); - } - }); + dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { + dispatch_bool(traditional_, [&](auto traditional) { + dispatch_bool(forward_, [&](auto forward) { + using DataType = cuda_type_t; + if (single && !with_freqs) { + auto kernel = + cu::rope_single; + uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); + encoder.add_kernel_node( + kernel, + grid, + block, + (donated ? out : in).data(), + out.data(), + offset.data(), + scale_, + std::log2(base_), + mat_size, + dims); + } else if (single) { + auto kernel = + cu::rope_single_freqs; + uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); + encoder.add_kernel_node( + kernel, + grid, + block, + (donated ? out : in).data(), + out.data(), + offset.data(), + inputs[2].data(), + scale_, + mat_size, + dims, + inputs[2].strides(0)); + } else if (with_freqs) { + auto kernel = + cu::rope_freqs; + uint3 dims = + make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); + dims.z = (dims.z + 3) / 4; + auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); + encoder.add_kernel_node( + kernel, + grid, + block, + (donated ? out : in).data(), + out.data(), + offset.data(), + inputs[2].data(), + scale_, + std::log2(base_), + strides, + out_strides, + in.size() / mat_size, + dims, + inputs[2].strides(0)); + } else { + auto kernel = cu::rope; + uint3 dims = + make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); + dims.z = (dims.z + 3) / 4; + auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); + encoder.add_kernel_node( + kernel, + grid, + block, + (donated ? out : in).data(), + out.data(), + offset.data(), + scale_, + std::log2(base_), + strides, + out_strides, + in.size() / mat_size, + dims); + } }); }); }); diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu index af9ddf214..fd807bd8d 100644 --- a/mlx/backend/cuda/softmax.cu +++ b/mlx/backend/cuda/softmax.cu @@ -141,19 +141,21 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { - constexpr int N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - auto kernel = cu::softmax; - if (precise) { - kernel = cu::softmax; - } - kernel<<>>( - in.data(), out.data(), axis_size); - }); + dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { + constexpr int N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::softmax; + if (precise) { + kernel = cu::softmax; + } + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + in.data(), + out.data(), + axis_size); }); }); } diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 2c5599bed..379c55706 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -50,32 +50,6 @@ array swapaxes_in_eval(const array& in, int axis1, int axis2) { return out; } -template -void segmented_sort_pairs(cu::CommandEncoder& encoder, Args&&... args) { - // Allocate temporary storage. - size_t size; - CHECK_CUDA_ERROR( - cub::DeviceSegmentedSort::StableSortPairs(nullptr, size, args...)); - array temp(allocator::malloc(size), {static_cast(size)}, uint8); - encoder.add_temporary(temp); - // Run op. - CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs( - temp.data(), size, args...)); -} - -template -void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) { - // Allocate temporary storage. - size_t size; - CHECK_CUDA_ERROR( - cub::DeviceSegmentedSort::StableSortKeys(nullptr, size, args...)); - array temp(allocator::malloc(size), {static_cast(size)}, uint8); - encoder.add_temporary(temp); - // Run op. - CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys( - temp.data(), size, args...)); -} - struct OffsetTransform { int nsort; @@ -113,57 +87,94 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto type_tag) { - using CTYPE = MLX_GET_TYPE(type_tag); - if constexpr (!std::is_same_v) { - using Type = cuda_type_t; - auto offsets = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), OffsetTransform{nsort}); - if (argsort) { - // Indices in the sorted dimension. - array indices( - allocator::malloc(out.nbytes()), in.shape(), out.dtype()); - encoder.add_temporary(indices); - thrust::transform( - cu::thrust_policy(stream), - thrust::counting_iterator(0), - thrust::counting_iterator(indices.data_size()), - thrust::device_pointer_cast(indices.data()), - ModOp{static_cast(nsort)}); + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + auto& stream = encoder.stream(); + if constexpr (!std::is_same_v) { + using Type = cuda_type_t; + auto offsets = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), OffsetTransform{nsort}); + if (argsort) { + // Indices in the sorted dimension. + array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype()); + encoder.add_temporary(indices); - // In argsort though we don't need the result of sorted values, the - // API requires us to provide an array to store it. - array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype()); - encoder.add_temporary(discard); + // In argsort though we don't need the result of sorted values, the + // API requires us to provide an array to store it. + array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype()); + encoder.add_temporary(discard); - segmented_sort_pairs( - encoder, - in.data(), - discard.data(), - indices.data(), - out.data(), - in.data_size(), - in.data_size() / nsort, - offsets, - offsets + 1, - stream); - } else { - segmented_sort( - encoder, - in.data(), - out.data(), - in.data_size(), - in.data_size() / nsort, - offsets, - offsets + 1, - stream); - } + size_t size; + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs( + nullptr, + size, + in.data(), + discard.data(), + indices.data(), + out.data(), + in.data_size(), + in.data_size() / nsort, + offsets, + offsets + 1, + stream)); + + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + + // Start capturing after allocations + auto capture = encoder.capture_context(); + thrust::transform( + cu::thrust_policy(stream), + thrust::counting_iterator(0), + thrust::counting_iterator(indices.data_size()), + thrust::device_pointer_cast(indices.data()), + ModOp{static_cast(nsort)}); + + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs( + temp.data(), + size, + in.data(), + discard.data(), + indices.data(), + out.data(), + in.data_size(), + in.data_size() / nsort, + offsets, + offsets + 1, + stream)); } else { - throw std::runtime_error( - "CUDA backend does not support sorting complex numbers"); + size_t size; + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys( + nullptr, + size, + in.data(), + out.data(), + in.data_size(), + in.data_size() / nsort, + offsets, + offsets + 1, + stream)); + + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + + // Start capturing after allocations + auto capture = encoder.capture_context(); + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys( + temp.data(), + size, + in.data(), + out.data(), + in.data_size(), + in.data_size() / nsort, + offsets, + offsets + 1, + stream)); } - }); + } else { + throw std::runtime_error( + "CUDA backend does not support sorting complex numbers"); + } }); if (!is_segmented_sort) { diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index 1d6535100..aa6523f27 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -91,73 +91,80 @@ void ternary_op_gpu_inplace( encoder.set_input_array(b); encoder.set_input_array(c); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(out.dtype(), [&](auto type_tag) { - using DType = cuda_type_t; + dispatch_all_types(out.dtype(), [&](auto type_tag) { + using DType = cuda_type_t; - auto topt = get_ternary_op_type(a, b, c); - if (topt == TernaryOpType::General) { - dispatch_bool( - a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || - c.data_size() > INT32_MAX || out.data_size() > INT32_MAX, - [&](auto large) { - using IdxT = std::conditional_t; - Shape shape; - std::vector strides; - std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out); - auto& a_strides = strides[0]; - auto& b_strides = strides[1]; - auto& c_strides = strides[2]; - int ndim = shape.size(); - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto kernel = - cu::ternary_g_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large()); - kernel<<>>( - a.data(), - b.data(), - c.data(), - out.data(), - out.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - const_param(c_strides)); - }); - } else { - auto kernel = cu::ternary_g; + auto topt = get_ternary_op_type(a, b, c); + if (topt == TernaryOpType::General) { + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + c.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + auto& c_strides = strides[2]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = + cu::ternary_g_nd; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, a.data(), b.data(), c.data(), out.data(), - out.data_size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - const_param(c_strides), - ndim); - } - }); - } else { - dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { - using IdxT = std::conditional_t; - auto kernel = cu::ternary_v; - auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large()); - kernel<<>>( - a.data(), - b.data(), - c.data(), - out.data(), - out.data_size()); - }); - } - }); + out.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + const_param(c_strides)); + }); + } else { + auto kernel = cu::ternary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + c.data(), + out.data(), + out.data_size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + const_param(c_strides), + ndim); + } + }); + } else { + dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; + auto kernel = cu::ternary_v; + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + c.data(), + out.data(), + out.data_size()); + }); + } }); } diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index 74251d1f6..3f1a62d24 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -9,14 +9,38 @@ #include "mlx/dtype_utils.h" #include "mlx/primitives.h" +#include #include -#include -#include namespace mlx::core { namespace cu { +namespace cg = cooperative_groups; + +template +__global__ void unary_v(const In* in, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = Op{}(in[index]); + } +} + +template +__global__ void unary_g( + const In* in, + Out* out, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides strides, + int ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto idx = elem_to_loc_4d(index, shape.data(), strides.data(), ndim); + out[index] = Op{}(in[idx]); + } +} + template constexpr bool supports_unary_op() { if (std::is_same_v || std::is_same_v || @@ -71,38 +95,61 @@ void unary_op_gpu_inplace( if (in.size() == 0) { return; } + bool contig = in.flags().contiguous; + bool large; + if (!contig) { + large = in.data_size() > INT32_MAX || out.size() > INT32_MAX; + } else { + large = in.data_size() > UINT32_MAX; + } auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - using CTYPE_IN = MLX_GET_TYPE(in_type_tag); - using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); - if constexpr (cu::supports_unary_op()) { + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); + if constexpr (cu::supports_unary_op()) { + dispatch_bool(large, [&](auto large) { + using IdxT = std::conditional_t; using InType = cuda_type_t; using OutType = cuda_type_t; - auto policy = cu::thrust_policy(stream); - auto in_ptr = thrust::device_pointer_cast(in.data()); - auto out_ptr = thrust::device_pointer_cast(out.data()); - if (in.flags().contiguous) { - thrust::transform( - policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op()); + using IdxT = std::conditional_t; + if (contig) { + auto kernel = cu::unary_v; + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), large); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in.data(), + out.data(), + out.data_size()); } else { auto [shape, strides] = collapse_contiguous_dims(in); - auto [in_begin, in_end] = cu::make_general_iterators( - in_ptr, in.size(), shape, strides); - thrust::transform(policy, in_begin, in_end, out_ptr, Op()); + auto kernel = cu::unary_g; + auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in.data(), + out.data(), + out.data_size(), + const_param(shape), + const_param(strides), + shape.size()); } - } else { - throw std::runtime_error(fmt::format( - "Can not do unary op {} on input of {} with output of {}.", - op, - dtype_to_string(in.dtype()), - dtype_to_string(out.dtype()))); - } - }); + }); + } else { + throw std::runtime_error(fmt::format( + "Can not do unary op {} on input of {} with output of {}.", + op, + dtype_to_string(in.dtype()), + dtype_to_string(out.dtype()))); + } }); }); } diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 35731f6eb..cc05428a8 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -24,6 +24,14 @@ void check_cuda_error(const char* name, cudaError_t err) { } } +void check_cuda_error(const char* name, CUresult err) { + if (err != CUDA_SUCCESS) { + const char* err_str = "Unknown error"; + cuGetErrorString(err, &err_str); + throw std::runtime_error(fmt::format("{} failed: {}", name, err_str)); + } +} + const char* dtype_to_cuda_type(const Dtype& dtype) { switch (dtype) { case bool_: diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h index 6d98cdcd5..bfb02c5b6 100644 --- a/mlx/backend/cuda/utils.h +++ b/mlx/backend/cuda/utils.h @@ -4,6 +4,7 @@ #pragma once +#include #include namespace mlx::core { @@ -33,6 +34,7 @@ class CudaStream { // Throw exception if the cuda API does not succeed. void check_cuda_error(const char* name, cudaError_t err); +void check_cuda_error(const char* name, CUresult err); // The macro version that prints the command that failed. #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 144f9a880..ff3208e1e 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -688,7 +688,7 @@ array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) { perm = expand_dims(perm, -1, s); take_axis -= 1; } - auto pb = take_along_axis(b, perm, take_axis); + auto pb = take_along_axis(b, perm, take_axis, s); auto y = solve_triangular(luf[1], pb, /* upper = */ false, s); return solve_triangular(luf[2], y, /* upper = */ true, s); } diff --git a/python/tests/test_load.py b/python/tests/test_load.py index 35f7016c5..840d3b471 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -391,9 +391,11 @@ class TestLoad(mlx_tests.MLXTestCase): scale = mx.array(2.0) y = mx.load(save_file) mx.eval(y) + mx.synchronize() load_only = mx.get_peak_memory() y = mx.load(save_file) * scale mx.eval(y) + mx.synchronize() load_with_binary = mx.get_peak_memory() self.assertEqual(load_only, load_with_binary)