diff --git a/mlx/backend/common/matmul.h b/mlx/backend/common/matmul.h index 2e0261a30..2faf256d1 100644 --- a/mlx/backend/common/matmul.h +++ b/mlx/backend/common/matmul.h @@ -12,16 +12,11 @@ namespace mlx::core { inline std::tuple collapse_batches( const array& a, const array& b) { - // Get and check the shape for the batched dims - Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; - Shape B_bshape{b.shape().begin(), b.shape().end() - 2}; - if (A_bshape != B_bshape) { - std::ostringstream msg; - msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A " - << a.shape() << ", B " << b.shape() << "."; - throw std::runtime_error(msg.str()); + if (a.ndim() == 2) { + return {{1}, {0}, {0}}; } + Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; @@ -42,17 +37,11 @@ inline std::tuple collapse_batches( inline std::tuple collapse_batches(const array& a, const array& b, const array& c) { - // Get and check the shape for the batched dims - Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; - Shape B_bshape{b.shape().begin(), b.shape().end() - 2}; - Shape C_bshape{c.shape().begin(), c.shape().end() - 2}; - if (A_bshape != B_bshape || A_bshape != C_bshape) { - std::ostringstream msg; - msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A " - << a.shape() << ", B " << b.shape() << ", B " << c.shape() << "."; - throw std::runtime_error(msg.str()); + if (a.ndim() == 2) { + return {{1}, {0}, {0}, {0}}; } + Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; Strides C_bstride{c.strides().begin(), c.strides().end() - 2}; diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu index 90f8561c1..ad942a406 100644 --- a/mlx/backend/cuda/arg_reduce.cu +++ b/mlx/backend/cuda/arg_reduce.cu @@ -151,30 +151,29 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) { - using T = cuda_type_t; - constexpr uint32_t N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); - auto kernel = - cu::arg_reduce_general, block_dim(), N_READS>; - if (reduce_type_ == ArgReduce::ArgMin) { - kernel = cu:: - arg_reduce_general, block_dim(), N_READS>; - } - kernel<<>>( - in.data(), - out.data(), - out.size(), - const_param(shape), - const_param(in_strides), - const_param(out_strides), - ndim, - axis_stride, - axis_size); - }); + dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) { + using T = cuda_type_t; + constexpr uint32_t N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); + auto kernel = + cu::arg_reduce_general, block_dim(), N_READS>; + if (reduce_type_ == ArgReduce::ArgMin) { + kernel = cu::arg_reduce_general, block_dim(), N_READS>; + } + encoder.add_kernel_node( + kernel, + num_blocks, + block_dim(), + in.data(), + out.data(), + out.size(), + const_param(shape), + const_param(in_strides), + const_param(out_strides), + ndim, + axis_stride, + axis_size); 
}); }); } diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index 8e476d30f..d9b9fd8af 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -139,90 +139,92 @@ void binary_op_gpu_inplace( encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(a.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - using CTYPE_IN = MLX_GET_TYPE(in_type_tag); - using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); - if constexpr (cu::supports_binary_op()) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - auto bopt = get_binary_op_type(a, b); - if (bopt == BinaryOpType::General) { - dispatch_bool( - a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || - out.data_size() > INT32_MAX, - [&](auto large) { - using IdxT = std::conditional_t; - Shape shape; - std::vector strides; - std::tie(shape, strides) = - collapse_contiguous_dims(a, b, out); - auto& a_strides = strides[0]; - auto& b_strides = strides[1]; - int ndim = shape.size(); - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto kernel = cu::binary_g_nd< - Op, - InType, - OutType, - IdxT, - dims_constant()>; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large()); - kernel<<>>( - a.data(), - b.data(), - out.data(), - out.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides)); - }); - } else { - auto kernel = cu::binary_g; + dispatch_all_types(a.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); + if constexpr (cu::supports_binary_op()) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + auto bopt = get_binary_op_type(a, b); + if (bopt == BinaryOpType::General) { + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + out.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = collapse_contiguous_dims(a, b, out); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu:: + binary_g_nd; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, a.data(), b.data(), out.data(), out.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - ndim); - } - }); - } else { - dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { - using IdxT = std::conditional_t; - auto kernel = cu::binary_ss; - if (bopt == BinaryOpType::ScalarVector) { - kernel = cu::binary_sv; - } else if (bopt == BinaryOpType::VectorScalar) { - kernel = cu::binary_vs; - } else if (bopt == BinaryOpType::VectorVector) { - kernel = cu::binary_vv; - } - auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large()); - kernel<<>>( - a.data(), - b.data(), - out.data(), - out.data_size()); - }); - } + const_param(shape), + const_param(a_strides), + const_param(b_strides)); + }); + } else { + auto kernel = cu::binary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out.data(), + 
out.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + ndim); + } + }); } else { - throw std::runtime_error(fmt::format( - "Can not do binary op {} on inputs of {} with result of {}.", - op, - dtype_to_string(a.dtype()), - dtype_to_string(out.dtype()))); + dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; + auto kernel = cu::binary_ss; + if (bopt == BinaryOpType::ScalarVector) { + kernel = cu::binary_sv; + } else if (bopt == BinaryOpType::VectorScalar) { + kernel = cu::binary_vs; + } else if (bopt == BinaryOpType::VectorVector) { + kernel = cu::binary_vv; + } + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out.data(), + out.data_size()); + }); } - }); + } else { + throw std::runtime_error(fmt::format( + "Can not do binary op {} on inputs of {} with result of {}.", + op, + dtype_to_string(a.dtype()), + dtype_to_string(out.dtype()))); + } }); }); } diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index 0a68e5f1d..9582b0378 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -137,98 +137,101 @@ void binary_op_gpu_inplace( encoder.set_input_array(b); encoder.set_output_array(out_a); encoder.set_output_array(out_b); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(a.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) { - using CTYPE_IN = MLX_GET_TYPE(in_type_tag); - using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); - if constexpr (cu::supports_binary_op()) { - using InType = cuda_type_t; - using OutType = cuda_type_t; + dispatch_all_types(a.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); + if constexpr (cu::supports_binary_op()) { + using InType = cuda_type_t; + using OutType = cuda_type_t; - auto bopt = get_binary_op_type(a, b); - if (bopt == BinaryOpType::General) { - dispatch_bool( - a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || - out_a.data_size() > INT32_MAX, - [&](auto large) { - using IdxT = std::conditional_t; - Shape shape; - std::vector strides; - std::tie(shape, strides) = - collapse_contiguous_dims(a, b, out_a); - auto& a_strides = strides[0]; - auto& b_strides = strides[1]; - int ndim = shape.size(); - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto kernel = cu::binary_g_nd< - Op, - InType, - OutType, - IdxT, - dims_constant()>; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out_a, large()); - kernel<<>>( - a.data(), - b.data(), - out_a.data(), - out_b.data(), - out_a.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides)); - }); - } else { - auto kernel = cu::binary_g; + auto bopt = get_binary_op_type(a, b); + if (bopt == BinaryOpType::General) { + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + out_a.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = + collapse_contiguous_dims(a, b, out_a); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu:: + binary_g_nd; auto 
[num_blocks, block_dims] = get_launch_args(kernel, out_a, large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, a.data(), b.data(), out_a.data(), out_b.data(), out_a.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - ndim); - } - }); - } else { - dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) { - using IdxT = std::conditional_t; - auto kernel = cu::binary_ss; - if (bopt == BinaryOpType::ScalarVector) { - kernel = cu::binary_sv; - } else if (bopt == BinaryOpType::VectorScalar) { - kernel = cu::binary_vs; - } else if (bopt == BinaryOpType::VectorVector) { - kernel = cu::binary_vv; - } - auto [num_blocks, block_dims] = get_launch_args( - kernel, - out_a.data_size(), - out_a.shape(), - out_a.strides(), - large()); - kernel<<>>( - a.data(), - b.data(), - out_a.data(), - out_b.data(), - out_a.data_size()); - }); - } + const_param(shape), + const_param(a_strides), + const_param(b_strides)); + }); + } else { + auto kernel = cu::binary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out_a, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out_a.data(), + out_b.data(), + out_a.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + ndim); + } + }); } else { - throw std::runtime_error(fmt::format( - "Can not do binary op {} on inputs of {} with result of {}.", - op, - dtype_to_string(a.dtype()), - dtype_to_string(out_a.dtype()))); + dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; + auto kernel = cu::binary_ss; + if (bopt == BinaryOpType::ScalarVector) { + kernel = cu::binary_sv; + } else if (bopt == BinaryOpType::VectorScalar) { + kernel = cu::binary_vs; + } else if (bopt == BinaryOpType::VectorVector) { + kernel = cu::binary_vv; + } + auto [num_blocks, block_dims] = get_launch_args( + kernel, + out_a.data_size(), + out_a.shape(), + out_a.strides(), + large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + out_a.data(), + out_b.data(), + out_a.data_size()); + }); } - }); + } else { + throw std::runtime_error(fmt::format( + "Can not do binary op {} on inputs of {} with result of {}.", + op, + dtype_to_string(a.dtype()), + dtype_to_string(out_a.dtype()))); + } }); }); } diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index 1aa7ecb92..21257e5dd 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -3,6 +3,7 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/jit_module.h" +#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/graph_utils.h" #include "mlx/primitives.h" @@ -178,6 +179,7 @@ void Compiled::eval_gpu( // Whether to use large index. bool large = compiled_use_large_index(inputs, outputs, contiguous); + cu::KernelArgs args; // Put inputs. int strides_index = 1; for (size_t i = 0; i < inputs.size(); ++i) { @@ -185,26 +187,26 @@ void Compiled::eval_gpu( continue; } const auto& x = inputs[i]; - mod.append_arg(x); + args.append(x); if (!contiguous && !is_scalar(x)) { - mod.append_arg(strides_vec[strides_index++]); + args.append_ptr(strides_vec[strides_index++].data()); } } // Put outputs. compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous); for (auto& x : outputs) { - mod.append_arg(x); + args.append(x); } // Put shape and size. 
if (!contiguous) { - mod.append_arg(shape); + args.append_ptr(shape.data()); } if (large) { - mod.append_arg(outputs[0].data_size()); + args.append(outputs[0].data_size()); } else { - mod.append_arg(outputs[0].data_size()); + args.append(outputs[0].data_size()); } // Launch kernel. @@ -222,9 +224,10 @@ void Compiled::eval_gpu( for (const auto& out : outputs) { encoder.set_output_array(out); } - encoder.launch_kernel([&](cudaStream_t stream) { - mod.launch_kernel(stream, kernel_name, outputs[0], large); - }); + + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); } } // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index 15858ded0..408350129 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -35,24 +35,25 @@ void copy_contiguous( array& out, int64_t in_offset, int64_t out_offset) { - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - using IdxT = std::conditional_t; - auto kernel = cu::copy_s; - if (ctype == CopyType::Vector) { - kernel = cu::copy_v; - } - auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large()); - kernel<<>>( - in.data() + in_offset, - out.data() + out_offset, - out.data_size()); - }); + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + auto kernel = cu::copy_s; + if (ctype == CopyType::Vector) { + kernel = cu::copy_v; + } + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in.data() + in_offset, + out.data() + out_offset, + out.data_size()); }); }); }); diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index b2703e4bf..5c7f9f954 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -55,50 +55,54 @@ void copy_general( const Shape& shape, const Strides& strides_in, const Strides& strides_out) { - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - dispatch_bool( - in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, - [&](auto large) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - using IdxT = std::conditional_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - int ndim = shape.size(); - size_t data_size = 1; - for (auto& s : shape) - data_size *= s; - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto ndim_constant) { - auto kernel = - cu::copy_gg_nd; - auto [num_blocks, block_dims] = get_launch_args( - kernel, data_size, shape, out.strides(), large()); - kernel<<>>( - in_ptr, - out_ptr, - data_size, - const_param(shape), - const_param(strides_in), - const_param(strides_out)); - }); - } else { // ndim >= 4 - 
auto kernel = cu::copy_gg; + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + size_t data_size = 1; + for (auto& s : shape) + data_size *= s; + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto ndim_constant) { + auto kernel = + cu::copy_gg_nd; auto [num_blocks, block_dims] = get_launch_args( kernel, data_size, shape, out.strides(), large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, in_ptr, out_ptr, data_size, - const_param(shape), - const_param(strides_in), - const_param(strides_out), - ndim); - } - }); - }); + const_param(shape), + const_param(strides_in), + const_param(strides_out)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg; + auto [num_blocks, block_dims] = get_launch_args( + kernel, data_size, shape, out.strides(), large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in_ptr, + out_ptr, + data_size, + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim); + } + }); }); }); } diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu index 68ad005d2..1b643111f 100644 --- a/mlx/backend/cuda/copy/copy_general_dynamic.cu +++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu @@ -61,54 +61,55 @@ void copy_general_dynamic( const Strides& strides_out, const array& dynamic_offset_in, const array& dynamic_offset_out) { - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - dispatch_bool( - in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, - [&](auto large) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - using IdxT = std::conditional_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - int ndim = shape.size(); - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto kernel = cu::copy_gg_dynamic_nd< - InType, - OutType, - IdxT, - dims_constant()>; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large()); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in), - const_param(strides_out), - dynamic_offset_in.data(), - dynamic_offset_out.data()); - }); - } else { // ndim >= 4 - auto kernel = cu::copy_gg_dynamic; + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu:: + copy_gg_dynamic_nd; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, in_ptr, out_ptr, out.size(), - const_param(shape), - const_param(strides_in), - const_param(strides_out), - ndim, + 
const_param(shape), + const_param(strides_in), + const_param(strides_out), dynamic_offset_in.data(), dynamic_offset_out.data()); - } - }); - }); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg_dynamic; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim, + dynamic_offset_in.data(), + dynamic_offset_out.data()); + } + }); }); }); } diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu index d83ba0854..1ac7712e6 100644 --- a/mlx/backend/cuda/copy/copy_general_input.cu +++ b/mlx/backend/cuda/copy/copy_general_input.cu @@ -50,45 +50,49 @@ void copy_general_input( int64_t offset_out, const Shape& shape, const Strides& strides_in) { - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - dispatch_bool( - in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, - [&](auto large) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - using IdxT = std::conditional_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - int ndim = shape.size(); - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto kernel = - cu::copy_g_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large()); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in)); - }); - } else { // ndim >= 4 - auto kernel = cu::copy_g; + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = + cu::copy_g_nd; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, in_ptr, out_ptr, out.size(), - const_param(shape), - const_param(strides_in), - ndim); - } - }); - }); + const_param(shape), + const_param(strides_in)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + ndim); + } + }); }); }); } diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index ba31c0e45..fff752fe5 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -2,38 +2,23 @@ #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/worker.h" -#include "mlx/backend/metal/metal.h" +#include "mlx/utils.h" #include #include #include +#include namespace mlx::core { +// Can be tuned with MLX_MAX_OPS_PER_BUFFER +// This should be less than 255 +constexpr int default_max_nodes_per_graph = 20; + +constexpr int max_graph_cache_size = 100; + namespace cu { -DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {} - -void DeviceStream::synchronize() 
{ - cudaStreamSynchronize(stream_); -} - -cudaStream_t DeviceStream::schedule_cuda_stream() { - // TODO: Return a stream that maximizes parallelism. - return stream_; -} - -cudaStream_t DeviceStream::last_cuda_stream() { - return stream_; -} - -CommandEncoder& DeviceStream::get_encoder() { - if (!encoder_) { - encoder_ = std::make_unique(*this); - } - return *encoder_; -} - Device::Device(int device) : device_(device) { CHECK_CUDA_ERROR(cudaDeviceGetAttribute( &compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_)); @@ -67,49 +52,253 @@ void Device::make_current() { } } -DeviceStream& Device::get_stream(Stream s) { - auto it = streams_.find(s.index); - if (it == streams_.end()) { - it = streams_.try_emplace(s.index, *this).first; +CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) { + CHECK_CUDA_ERROR(cudaGraphCreate(&graph, 0)); + CHECK_CUDA_ERROR(cudaStreamBeginCaptureToGraph( + enc.stream(), graph, NULL, NULL, 0, cudaStreamCaptureModeGlobal)); +} + +CommandEncoder::CaptureContext::~CaptureContext() { + CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph)); + size_t num_nodes; + CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes)); + if (num_nodes == 1) { + cudaGraphNode_t captured_node; + CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes)); + CUDA_KERNEL_NODE_PARAMS params; + CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, ¶ms)); + cudaGraphNode_t node; + CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, ¶ms)); + enc.insert_graph_dependencies(GraphNode{node, 'K'}); + } else { + cudaGraphNode_t node; + CHECK_CUDA_ERROR( + cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph)); + enc.insert_graph_dependencies(GraphNode{node, 'G'}); + } + CHECK_CUDA_ERROR(cudaGraphDestroy(graph)); +} + +CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc) + : enc(enc) { + enc.in_concurrent_ = true; +} + +CommandEncoder::ConcurrentContext::~ConcurrentContext() { + enc.in_concurrent_ = false; + + // Use an empty graph node for synchronization + CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)}; + enc.empty_node_count_++; + CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0)); + + // Insert the concurrent -> empty node dependencies + for (auto& from : enc.concurrent_nodes_) { + enc.from_nodes_.push_back(from.node); + enc.to_nodes_.push_back(empty.node); + enc.graph_key_ += from.id; + enc.graph_key_ += from.node_type; + enc.graph_key_ += empty.id; + enc.graph_key_ += empty.node_type; + } + + // Insert the input -> concurrent node dependencies without updating output + // nodes + auto outputs = std::move(enc.active_outputs_); + enc.insert_graph_dependencies(std::move(enc.concurrent_nodes_)); + + // Update output node to be the empty node + for (auto o : outputs) { + enc.node_map_.emplace(o, empty).first->second = empty; + } +} + +void CommandEncoder::insert_graph_dependencies(GraphNode node) { + if (node.node_type == 'G') { + graph_node_count_++; + } + node.id = std::to_string(node_count_++); + if (in_concurrent_) { + concurrent_nodes_.push_back(std::move(node)); + } else { + std::vector nodes; + nodes.push_back(std::move(node)); + insert_graph_dependencies(std::move(nodes)); + } +} + +void CommandEncoder::insert_graph_dependencies(std::vector nodes) { + std::vector deps; + { + // Dependencies must be added in the same order to produce a consistent + // topology + std::unordered_set set_deps; + for (auto d : active_deps_) { + if 
(auto it = node_map_.find(d); it != node_map_.end()) { + auto [_, inserted] = set_deps.insert(it->second.node); + if (inserted) { + deps.push_back(it->second); + } + } + } + } + active_deps_.clear(); + + for (auto o : active_outputs_) { + for (auto& node : nodes) { + node_map_.emplace(o, node).first->second = node; + } + } + active_outputs_.clear(); + + for (auto& from : deps) { + for (auto& to : nodes) { + from_nodes_.push_back(from.node); + to_nodes_.push_back(to.node); + graph_key_ += from.id; + graph_key_ += from.node_type; + graph_key_ += to.id; + graph_key_ += to.node_type; + } + } +} + +CommandEncoder& Device::get_command_encoder(Stream s) { + auto it = encoders_.find(s.index); + if (it == encoders_.end()) { + it = encoders_.try_emplace(s.index, *this).first; } return it->second; } -CommandEncoder::CommandEncoder(DeviceStream& s) - : device_(s.device()), stream_(s) {} +CommandEncoder::CommandEncoder(Device& d) : stream_(d) { + CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0)); +} + +void clear_graphs(std::unordered_map& graphs) { + for (auto& [_, graph_exec] : graphs) { + CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec)); + } + graphs.clear(); +} + +CommandEncoder::~CommandEncoder() { + clear_graphs(graph_cache_); +} void CommandEncoder::add_completed_handler(std::function task) { worker_.add_task(std::move(task)); } -void CommandEncoder::end_encoding() { - if (!temporaries_.empty()) { - add_completed_handler([temporaries = std::move(temporaries_)]() {}); - } +void CommandEncoder::set_input_array(const array& arr) { + auto id = reinterpret_cast(arr.buffer().ptr()); + active_deps_.push_back(id); +} - // There is no kernel running, run completion handlers immediately. - if (!has_gpu_work_) { - worker_.consume_in_this_thread(); - return; - } - has_gpu_work_ = false; +void CommandEncoder::set_output_array(const array& arr) { + auto id = reinterpret_cast(arr.buffer().ptr()); + active_deps_.push_back(id); + active_outputs_.push_back(id); +} - // Put completion handlers in a batch. - worker_.end_batch(); - - // Signaling kernel completion is expensive, delay until enough batches. - // TODO: This number is arbitrarily picked, profile for a better stragety. 
- if (worker_.uncommited_batches() > 8) { +void CommandEncoder::maybe_commit() { + if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) { commit(); } } +void CommandEncoder::add_kernel_node( + void* func, + dim3 grid_dim, + dim3 block_dim, + void** params) { + cudaKernelNodeParams kernel_params = {0}; + kernel_params.func = func; + kernel_params.gridDim = grid_dim; + kernel_params.blockDim = block_dim; + kernel_params.kernelParams = params; + cudaGraphNode_t node; + CHECK_CUDA_ERROR( + cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params)); + insert_graph_dependencies(GraphNode{node, 'K'}); +} + +void CommandEncoder::add_kernel_node( + CUfunction func, + dim3 grid_dim, + dim3 block_dim, + void** params) { + CUDA_KERNEL_NODE_PARAMS kernel_params = {0}; + kernel_params.func = func; + kernel_params.gridDimX = grid_dim.x; + kernel_params.gridDimY = grid_dim.y; + kernel_params.gridDimZ = grid_dim.z; + kernel_params.blockDimX = block_dim.x; + kernel_params.blockDimY = block_dim.y; + kernel_params.blockDimZ = block_dim.z; + kernel_params.kernelParams = params; + CUgraphNode node; + CHECK_CUDA_ERROR( + cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params)); + insert_graph_dependencies(GraphNode{node, 'K'}); +} + void CommandEncoder::commit() { - worker_.commit(stream_.last_cuda_stream()); + if (!temporaries_.empty()) { + add_completed_handler([temporaries = std::move(temporaries_)]() {}); + } + if (node_count_ > 0) { + if (!from_nodes_.empty()) { + CHECK_CUDA_ERROR(cudaGraphAddDependencies( + graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size())); + } + // TODO smarter cache policy + if (graph_cache_.size() > max_graph_cache_size) { + clear_graphs(graph_cache_); + } + + graph_key_ += "."; + graph_key_ += std::to_string(node_count_); + graph_key_ += "."; + graph_key_ += std::to_string(graph_node_count_); + graph_key_ += "."; + graph_key_ += std::to_string(empty_node_count_); + auto [it, _] = graph_cache_.emplace(graph_key_, nullptr); + auto& graph_exec = it->second; + + if (graph_exec != NULL) { + cudaGraphExecUpdateResultInfo update_result; + cudaGraphExecUpdate(graph_exec, graph_, &update_result); + if (update_result.result != cudaGraphExecUpdateSuccess) { + cudaGetLastError(); + CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec)); + graph_exec = NULL; + } + } + if (graph_exec == NULL) { + CHECK_CUDA_ERROR( + cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0)); + } + CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_)); + + // Reset state + node_count_ = 0; + graph_node_count_ = 0; + from_nodes_.clear(); + to_nodes_.clear(); + graph_key_.clear(); + node_map_.clear(); + CHECK_CUDA_ERROR(cudaGraphDestroy(graph_)); + CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0)); + } + + // Put completion handlers in a batch. 
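Note on the caching strategy in commit() above: the graph key encodes the recorded topology (node ids and types plus the node/subgraph/empty counts), so a stream that replays the same topology reuses the instantiated executable, and cudaGraphExecUpdate avoids the cost of re-instantiation when only kernel parameters changed. Distilled sketch of that update-or-rebuild fallback, using the encoder's own graph_exec, graph_ and stream_ members:

// Distilled from commit(): try to update the cached executable in place,
// rebuild it only when CUDA rejects the update, then launch.
if (graph_exec != nullptr) {
  cudaGraphExecUpdateResultInfo update_result;
  cudaGraphExecUpdate(graph_exec, graph_, &update_result);
  if (update_result.result != cudaGraphExecUpdateSuccess) {
    cudaGetLastError();              // clear the pending update error
    cudaGraphExecDestroy(graph_exec);
    graph_exec = nullptr;
  }
}
if (graph_exec == nullptr) {
  cudaGraphInstantiate(&graph_exec, graph_, nullptr, nullptr, 0);
}
cudaGraphLaunch(graph_exec, stream_);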
+ worker_.end_batch(); + worker_.commit(stream_); } void CommandEncoder::synchronize() { - stream().synchronize(); + cudaStreamSynchronize(stream_); auto p = std::make_shared>(); std::future f = p->get_future(); add_completed_handler([p = std::move(p)]() { p->set_value(); }); @@ -127,12 +316,8 @@ Device& device(mlx::core::Device device) { return it->second; } -DeviceStream& get_stream(Stream s) { - return device(s.device).get_stream(s); -} - CommandEncoder& get_command_encoder(Stream s) { - return get_stream(s).get_encoder(); + return device(s.device).get_command_encoder(s); } } // namespace cu diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index 744f77f62..4ebdae55c 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -7,41 +7,108 @@ #include "mlx/stream.h" #include +#include #include #include namespace mlx::core::cu { -class Device; -class CommandEncoder; - -class DeviceStream { +class CommandEncoder { public: - explicit DeviceStream(Device& device); + struct CaptureContext { + CaptureContext(CommandEncoder& enc); + ~CaptureContext(); + cudaGraph_t graph; + CommandEncoder& enc; + }; + struct ConcurrentContext { + ConcurrentContext(CommandEncoder& enc); + ~ConcurrentContext(); + CommandEncoder& enc; + }; - DeviceStream(const DeviceStream&) = delete; - DeviceStream& operator=(const DeviceStream&) = delete; + explicit CommandEncoder(Device& d); + ~CommandEncoder(); - // Wait until kernels in the stream complete. - void synchronize(); + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; - // Return a cuda stream for launching kernels. - cudaStream_t schedule_cuda_stream(); - - // Return the last cuda stream used. - cudaStream_t last_cuda_stream(); - - CommandEncoder& get_encoder(); - - Device& device() { - return device_; + CaptureContext capture_context() { + return CaptureContext{*this}; + } + ConcurrentContext concurrent_context() { + return ConcurrentContext{*this}; } + void set_input_array(const array& arr); + void set_output_array(const array& arr); + + template + void + add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... 
params) { + constexpr size_t num = sizeof...(Params); + void* ptrs[num]; + size_t i = 0; + ([&](auto&& p) { ptrs[i++] = static_cast(&p); }( + std::forward(params)), + ...); + add_kernel_node((void*)func, grid_dim, block_dim, ptrs); + } + + void add_kernel_node( + CUfunction func, + dim3 grid_dim, + dim3 block_dim, + void** params); + + void + add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params); + + void add_temporary(const array& arr) { + temporaries_.push_back(arr.data_shared_ptr()); + } + + void add_completed_handler(std::function task); + void maybe_commit(); + void commit(); + + CudaStream& stream() { + return stream_; + } + + // Wait until kernels and completion handlers are finished + void synchronize(); + private: - Device& device_; + struct GraphNode { + cudaGraphNode_t node; + // K = kernel + // E = empty + // G = subgraph + char node_type; + std::string id; + }; + + void insert_graph_dependencies(GraphNode node); + void insert_graph_dependencies(std::vector nodes); + CudaStream stream_; - std::unique_ptr encoder_; + cudaGraph_t graph_; + Worker worker_; + char node_count_{0}; + char graph_node_count_{0}; + char empty_node_count_{0}; + bool in_concurrent_{false}; + std::vector from_nodes_; + std::vector to_nodes_; + std::string graph_key_; + std::vector concurrent_nodes_; + std::vector> temporaries_; + std::unordered_map graph_cache_; + std::vector active_deps_; + std::vector active_outputs_; + std::unordered_map node_map_; }; class Device { @@ -55,7 +122,7 @@ class Device { // Make this device the current cuda device, required by some cuda calls. void make_current(); - DeviceStream& get_stream(Stream s); + CommandEncoder& get_command_encoder(Stream s); int cuda_device() const { return device_; @@ -75,67 +142,10 @@ class Device { int compute_capability_major_; int compute_capability_minor_; cublasLtHandle_t lt_; - std::unordered_map streams_; -}; - -class CommandEncoder { - public: - explicit CommandEncoder(DeviceStream& stream); - - CommandEncoder(const CommandEncoder&) = delete; - CommandEncoder& operator=(const CommandEncoder&) = delete; - - void set_input_array(const array& arr) {} - void set_output_array(const array& arr) {} - - void add_temporary(const array& arr) { - temporaries_.push_back(arr.data_shared_ptr()); - } - - void add_completed_handler(std::function task); - void end_encoding(); - void commit(); - - // Schedule a cuda stream for |fun| to launch kernels, and check error - // afterwards. - template - void launch_kernel(F&& fun) { - launch_kernel(stream_.schedule_cuda_stream(), std::forward(fun)); - } - - template - void launch_kernel(cudaStream_t stream, F&& fun) { - device_.make_current(); - fun(stream); - check_cuda_error("kernel launch", cudaGetLastError()); - has_gpu_work_ = true; - } - - Device& device() { - return device_; - } - - DeviceStream& stream() { - return stream_; - } - - bool has_gpu_work() const { - return has_gpu_work_; - } - - // Wait until kernels and completion handlers are finished - void synchronize(); - - private: - Device& device_; - DeviceStream& stream_; - Worker worker_; - bool has_gpu_work_{false}; - std::vector> temporaries_; + std::unordered_map encoders_; }; Device& device(mlx::core::Device device); -DeviceStream& get_stream(Stream s); CommandEncoder& get_command_encoder(Stream s); // Return an execution policy that does not sync for result. 
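The CommandEncoder declared above replaces eager kernel launches with graph-node insertion: call sites record buffer dependencies, add a kernel node, and let commit() instantiate and launch the accumulated graph. A minimal sketch of the call-site pattern, assuming a hypothetical kernel cu::my_kernel and a fixed launch geometry; neither is part of this change, the shape simply mirrors the call sites updated in this diff:

// Sketch only: cu::my_kernel is a placeholder.
void MyPrimitive::eval_gpu(const std::vector<array>& inputs, array& out) {
  out.set_data(allocator::malloc(out.nbytes()));
  auto& encoder = cu::get_command_encoder(stream());
  encoder.set_input_array(inputs[0]);  // read dependency on the input buffer
  encoder.set_output_array(out);       // write dependency on the output buffer
  auto kernel = cu::my_kernel<float>;
  auto [num_blocks, block_dims] = get_launch_args(kernel, out, /* large */ false);
  // Adds a kernel node to the encoder's graph; nothing executes until commit().
  encoder.add_kernel_node(
      kernel,
      num_blocks,
      block_dims,
      inputs[0].data<float>(),
      out.data<float>(),
      out.size());
}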
diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp index 21b019cd8..40beb12d2 100644 --- a/mlx/backend/cuda/eval.cpp +++ b/mlx/backend/cuda/eval.cpp @@ -37,22 +37,20 @@ void eval(array& arr) { } auto& encoder = cu::get_command_encoder(arr.primitive().stream()); - if (encoder.has_gpu_work()) { - // Keep used buffers alive until kernel finishes running. - std::unordered_set> buffers; - for (auto& in : arr.inputs()) { - buffers.insert(in.data_shared_ptr()); - } - for (auto& s : arr.siblings()) { - buffers.insert(s.data_shared_ptr()); - } - // Remove the output if it was donated to by an input. - if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { - buffers.erase(it); - } - encoder.add_completed_handler([buffers = std::move(buffers)]() {}); + // Keep used buffers alive until kernel finishes running. + std::unordered_set> buffers; + for (auto& in : arr.inputs()) { + buffers.insert(in.data_shared_ptr()); } - encoder.end_encoding(); + for (auto& s : arr.siblings()) { + buffers.insert(s.data_shared_ptr()); + } + // Remove the output if it was donated to by an input. + if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { + buffers.erase(it); + } + encoder.add_completed_handler([buffers = std::move(buffers)]() {}); + encoder.maybe_commit(); } void finalize(Stream s) { diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu index 9fc5c641b..afa032a83 100644 --- a/mlx/backend/cuda/event.cu +++ b/mlx/backend/cuda/event.cu @@ -61,7 +61,9 @@ void CudaEvent::wait(Stream s) { if (s.device == mlx::core::Device::cpu) { scheduler::enqueue(s, [*this]() mutable { wait(); }); } else { - wait(cu::get_stream(s).last_cuda_stream()); + auto& enc = cu::get_command_encoder(s); + enc.commit(); + wait(enc.stream()); } } @@ -74,7 +76,9 @@ void CudaEvent::record(Stream s) { if (s.device == mlx::core::Device::cpu) { throw std::runtime_error("CudaEvent can not wait on cpu stream."); } else { - record(cu::get_stream(s).last_cuda_stream()); + auto& enc = cu::get_command_encoder(s); + enc.commit(); + record(enc.stream()); } } @@ -136,11 +140,9 @@ void SharedEvent::wait(Stream s, uint64_t value) { scheduler::enqueue(s, [*this, value]() mutable { wait(value); }); } else { auto& encoder = get_command_encoder(s); - encoder.launch_kernel( - encoder.stream().last_cuda_stream(), - [this, value](cudaStream_t stream) { wait(stream, value); }); + encoder.commit(); + wait(encoder.stream(), value); encoder.add_completed_handler([ac = ac_]() {}); - encoder.end_encoding(); } } @@ -162,11 +164,9 @@ void SharedEvent::signal(Stream s, uint64_t value) { scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); }); } else { auto& encoder = get_command_encoder(s); - encoder.launch_kernel( - encoder.stream().last_cuda_stream(), - [this, value](cudaStream_t stream) { signal(stream, value); }); + encoder.commit(); + signal(encoder.stream(), value); encoder.add_completed_handler([ac = ac_]() {}); - encoder.end_encoding(); } } diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index 65a175fbd..4b03a604e 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -3,13 +3,16 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/jit_module.h" +#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" #include "cuda_jit_sources.h" +#include #include +#include #include #include @@ -22,7 +25,7 @@ 
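The JIT-backed indexing primitives below (Gather, Scatter, GatherAxis, ScatterAxis) now collect their kernel arguments in a cu::KernelArgs value and add the kernel node themselves instead of going through JitModule::launch_kernel. A condensed sketch of that pattern; src, out, mod, encoder, large and kernel_name are the in-scope names from the call sites below, the argument list is abbreviated, and the explicit uint32_t cast just illustrates the large/regular index split:

// Condensed JIT launch pattern; real call sites append full shape/stride/index lists.
cu::KernelArgs args;
args.append(src);                   // device pointer of an input array
args.append(out);                   // device pointer of the output array
args.append<uint32_t>(out.size());  // scalars are copied into stable storage
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());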
namespace { constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"}; void append_indices_arg( - cu::JitModule& mod, + cu::KernelArgs& args, const std::vector& inputs, int nidx, int idx_ndim) { @@ -30,7 +33,7 @@ void append_indices_arg( for (int i = 0; i < nidx; ++i) { indices[i] = inputs[i + 1].data(); } - mod.append_arg(std::move(indices)); + args.append(std::move(indices)); std::vector indices_shape(nidx * idx_ndim); for (int i = 0; i < nidx; ++i) { std::copy_n( @@ -38,7 +41,7 @@ void append_indices_arg( idx_ndim, indices_shape.data() + i * idx_ndim); } - mod.append_arg(std::move(indices_shape)); + args.append(std::move(indices_shape)); std::vector indices_strides(nidx * idx_ndim); for (int i = 0; i < nidx; ++i) { std::copy_n( @@ -46,7 +49,7 @@ void append_indices_arg( idx_ndim, indices_strides.data() + i * idx_ndim); } - mod.append_arg(std::move(indices_strides)); + args.append(std::move(indices_strides)); } } // namespace @@ -94,20 +97,21 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { return std::make_pair(jit_source_gather, std::move(kernel_names)); }); - mod.append_arg(src); - mod.append_arg(out); + cu::KernelArgs args; + args.append(src); + args.append(out); if (large) { - mod.append_arg(out.size()); + args.append(out.size()); } else { - mod.append_arg(out.size()); + args.append(out.size()); } - mod.append_ndim_arg(src.shape()); - mod.append_ndim_arg(src.strides()); - mod.append_arg(src.ndim()); - mod.append_ndim_arg(slice_sizes_); - mod.append_arg(slice_size); - mod.append_arg(axes_); - append_indices_arg(mod, inputs, nidx, idx_ndim); + args.append_ndim(src.shape()); + args.append_ndim(src.strides()); + args.append(src.ndim()); + args.append_ndim(slice_sizes_); + args.append(slice_size); + args.append(axes_); + append_indices_arg(args, inputs, nidx, idx_ndim); std::string kernel_name = fmt::format( "mlx::core::cu::gather<{}, {}, {}, {}, {}>", @@ -122,9 +126,10 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - mod.launch_kernel(stream, kernel_name, out, large); - }); + + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); } void Scatter::eval_gpu(const std::vector& inputs, array& out) { @@ -187,26 +192,27 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { return std::make_pair(jit_source_scatter, std::move(kernel_names)); }); - mod.append_arg(upd); - mod.append_arg(out); + cu::KernelArgs args; + args.append(upd); + args.append(out); if (large) { - mod.append_arg(upd.size()); + args.append(upd.size()); } else { - mod.append_arg(upd.size()); + args.append(upd.size()); } - mod.append_ndim_arg(upd.shape()); - mod.append_ndim_arg(upd.strides()); - mod.append_arg(upd.ndim()); + args.append_ndim(upd.shape()); + args.append_ndim(upd.strides()); + args.append(upd.ndim()); if (large) { - mod.append_arg(upd_post_idx_size); + args.append(upd_post_idx_size); } else { - mod.append_arg(upd_post_idx_size); + args.append(upd_post_idx_size); } - mod.append_ndim_arg(out.shape()); - mod.append_ndim_arg(out.strides()); - mod.append_arg(out.ndim()); - mod.append_arg(axes_); - append_indices_arg(mod, inputs, nidx, idx_ndim); + args.append_ndim(out.shape()); + args.append_ndim(out.strides()); + args.append(out.ndim()); + args.append(axes_); + append_indices_arg(args, inputs, nidx, 
idx_ndim); std::string kernel_name = fmt::format( "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>", @@ -222,9 +228,9 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - mod.launch_kernel(stream, kernel_name, upd, large); - }); + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); } void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { @@ -275,25 +281,26 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { } size_t idx_size_axis = idx.shape(axis_); - mod.append_arg(src); - mod.append_arg(idx); - mod.append_arg(out); + cu::KernelArgs args; + args.append(src); + args.append(idx); + args.append(out); if (large) { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + args.append(idx_size_pre); + args.append(idx_size_axis); + args.append(idx_size_post); } else { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + args.append(idx_size_pre); + args.append(idx_size_axis); + args.append(idx_size_post); } - mod.append_arg(remove_index(idx.shape(), axis_)); - mod.append_arg(remove_index(src.strides(), axis_)); - mod.append_arg(remove_index(idx.strides(), axis_)); - mod.append_arg(axis_); - mod.append_arg(src.shape(axis_)); - mod.append_arg(src.strides(axis_)); - mod.append_arg(idx.strides(axis_)); + args.append(remove_index(idx.shape(), axis_)); + args.append(remove_index(src.strides(), axis_)); + args.append(remove_index(idx.strides(), axis_)); + args.append(axis_); + args.append(src.shape(axis_)); + args.append(src.strides(axis_)); + args.append(idx.strides(axis_)); std::string kernel_name = fmt::format( "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>", @@ -309,9 +316,9 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - mod.launch_kernel(stream, kernel_name, idx, large); - }); + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); } void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { @@ -377,25 +384,26 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { } size_t idx_size_axis = idx.shape(axis_); - mod.append_arg(upd); - mod.append_arg(idx); - mod.append_arg(out); + cu::KernelArgs args; + args.append(upd); + args.append(idx); + args.append(out); if (large) { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + args.append(idx_size_pre); + args.append(idx_size_axis); + args.append(idx_size_post); } else { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + args.append(idx_size_pre); + args.append(idx_size_axis); + args.append(idx_size_post); } - mod.append_arg(remove_index(idx.shape(), axis_)); - mod.append_arg(remove_index(upd.strides(), axis_)); - mod.append_arg(remove_index(idx.strides(), axis_)); - mod.append_arg(axis_); - mod.append_arg(out.shape(axis_)); - mod.append_arg(upd.strides(axis_)); - mod.append_arg(idx.strides(axis_)); + args.append(remove_index(idx.shape(), axis_)); + 
args.append(remove_index(upd.strides(), axis_)); + args.append(remove_index(idx.strides(), axis_)); + args.append(axis_); + args.append(out.shape(axis_)); + args.append(upd.strides(axis_)); + args.append(idx.strides(axis_)); std::string kernel_name = fmt::format( "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>", @@ -412,9 +420,9 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - mod.launch_kernel(stream, kernel_name, idx, large); - }); + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); } } // namespace mlx::core diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index af8f7dc75..5bc56b25e 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -26,16 +26,6 @@ void check_nvrtc_error(const char* name, nvrtcResult err) { } } -#define CHECK_CU_ERROR(cmd) check_cu_error(#cmd, (cmd)) - -void check_cu_error(const char* name, CUresult err) { - if (err != CUDA_SUCCESS) { - const char* err_str = "Unknown error"; - cuGetErrorString(err, &err_str); - throw std::runtime_error(fmt::format("{} failed: {}", name, err_str)); - } -} - // Return the location of the CUDA toolkit. const std::string& cuda_home() { static std::string home = []() -> std::string { @@ -280,60 +270,13 @@ JitModule::JitModule( // Load kernels. for (const auto& [name, mangled] : ptx_kernels) { CUfunction kernel; - CHECK_CU_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str())); + CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str())); kernels_[name] = kernel; } } JitModule::~JitModule() { - CHECK_CU_ERROR(cuModuleUnload(module_)); -} - -void JitModule::launch_kernel( - CUstream stream, - const std::string& kernel_name, - const array& arr, - bool large, - int work_per_thread) { - CUfunction kernel = get_kernel(kernel_name); - size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread); - int _, block_dim; - CHECK_CU_ERROR( - cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0)); - if (block_dim > nthreads) { - block_dim = nthreads; - } - Dims num_blocks{1, 1, 1}; - if (large) { - num_blocks = - get_2d_grid_dims_common(arr.shape(), arr.strides(), work_per_thread); - std::get<0>(num_blocks) = - (std::get<0>(num_blocks) + block_dim - 1) / block_dim; - } else { - std::get<0>(num_blocks) = (nthreads + block_dim - 1) / block_dim; - } - launch_kernel(stream, kernel, num_blocks, Dims{block_dim, 1, 1}); -} - -void JitModule::launch_kernel( - CUstream stream, - CUfunction kernel, - Dims num_blocks, - Dims block_dims) { - CHECK_CU_ERROR(cuLaunchKernel( - kernel, - std::get<0>(num_blocks), - std::get<1>(num_blocks), - std::get<2>(num_blocks), - std::get<0>(block_dims), - std::get<1>(block_dims), - std::get<2>(block_dims), - 0, - stream, - args_.data(), - nullptr)); - args_.clear(); - storage_.clear(); + CHECK_CUDA_ERROR(cuModuleUnload(module_)); } CUfunction JitModule::get_kernel(const std::string& kernel_name) { @@ -345,10 +288,6 @@ CUfunction JitModule::get_kernel(const std::string& kernel_name) { return it->second; } -void JitModule::append_ptr_arg(const void* v) { - args_.push_back(const_cast(v)); -} - JitModule& get_jit_module( const mlx::core::Device& device, const std::string& name, diff --git a/mlx/backend/cuda/jit_module.h 
b/mlx/backend/cuda/jit_module.h index bbfaa45b0..57da7c87e 100644 --- a/mlx/backend/cuda/jit_module.h +++ b/mlx/backend/cuda/jit_module.h @@ -4,6 +4,7 @@ #include "mlx/array.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/config.h" #include @@ -23,72 +24,48 @@ using KernelBuilderResult = std::pair< /* kernel names */ std::vector>; using KernelBuilder = std::function; -class JitModule { - public: - JitModule( - Device& device, - const std::string& module_name, - const KernelBuilder& builder); - ~JitModule(); +struct KernelArgs { + void** args() { + return args_.data(); + } - JitModule(const JitModule&) = delete; - JitModule& operator=(const JitModule&) = delete; - - void append_arg(const array& a) { - append_arg(reinterpret_cast(a.data())); + void append(const array& a) { + append(reinterpret_cast(a.data())); } template - void append_arg(T val) { + void append(T val) { storage_.emplace_back(val); - append_ptr_arg(&storage_.back()); + append_ptr(&storage_.back()); } template - void append_arg(std::vector vec) { + void append(std::vector vec) { if (vec.empty()) { // The nullptr can not be used as arg, pass something not null. - append_arg(std::monostate{}); + append(std::monostate{}); } else { - append_ptr_arg(vec.data()); + append_ptr(vec.data()); storage_.emplace_back(std::move(vec)); } } // Make sure the arg is copied to an array with size of NDIM. template - void append_ndim_arg(const std::vector& vec) { + void append_ndim(std::vector vec) { if (vec.size() > NDIM) { throw std::runtime_error( fmt::format("ndim can not be larger than {}.", NDIM)); } - std::vector copied(NDIM); - std::copy(vec.begin(), vec.end(), copied.data()); - append_arg(std::move(copied)); + vec.resize(NDIM); + append(std::move(vec)); } - // Launch kernel with |kernel_name| that each thread works on - // |work_per_thread| elements of |arr|. 
- void launch_kernel( - CUstream stream, - const std::string& kernel_name, - const array& arr, - bool large, - int work_per_thread = 1); - - void launch_kernel( - CUstream stream, - CUfunction kernel, - Dims num_blocks, - Dims block_dims); - - CUfunction get_kernel(const std::string& kernel_name); + void append_ptr(const void* v) { + args_.push_back(const_cast(v)); + } private: - void append_ptr_arg(const void* v); - - CUmodule module_{nullptr}; - std::unordered_map kernels_; std::vector args_; // The cuLaunchKernel API requires passing pointers to arguments so store @@ -105,6 +82,23 @@ class JitModule { std::deque storage_; }; +class JitModule { + public: + JitModule( + Device& device, + const std::string& module_name, + const KernelBuilder& builder); + ~JitModule(); + + JitModule(const JitModule&) = delete; + JitModule& operator=(const JitModule&) = delete; + CUfunction get_kernel(const std::string& kernel_name); + + private: + CUmodule module_{nullptr}; + std::unordered_map kernels_; +}; + JitModule& get_jit_module( const mlx::core::Device& device, const std::string& name, diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index b0058b618..eeaf916c1 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -12,6 +12,7 @@ #include "mlx/backend/cuda/device/utils.cuh" #include +#include #include #include #include @@ -120,7 +121,13 @@ std::pair get_grid_and_block(int dim0, int dim1, int dim2); template inline uint max_occupancy_block_dim(T kernel) { int _, block_dim; - CHECK_CUDA_ERROR(cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel)); + if constexpr (std::is_same_v) { + CHECK_CUDA_ERROR( + cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0)); + } else { + CHECK_CUDA_ERROR( + cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel)); + } return block_dim; } diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu index 9a9fbcb37..5fbf949d7 100644 --- a/mlx/backend/cuda/layer_norm.cu +++ b/mlx/backend/cuda/layer_norm.cu @@ -258,23 +258,23 @@ void LayerNorm::eval_gpu( encoder.set_input_array(w); encoder.set_input_array(b); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) { - constexpr uint32_t N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - auto kernel = cu::layer_norm; - kernel<<>>( - x.data(), - w.data(), - b.data(), - out.data(), - eps_, - axis_size, - w_stride, - b_stride); - }); + dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) { + constexpr uint32_t N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::layer_norm; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + x.data(), + w.data(), + b.data(), + out.data(), + eps_, + axis_size, + w_stride, + b_stride); }); }); } @@ -289,21 +289,25 @@ void LayerNormVJP::eval_gpu( // Ensure row contiguity. We could relax this step by checking that the array // is contiguous (no broadcasts or holes) and that the input strides are the // same as the cotangent strides but for now this is simpler. 
- auto check_input = [&s](const array& x) -> std::pair { + auto check_input = [&s](const array& x, bool& copied) { if (x.flags().row_contiguous) { - return {x, false}; + copied = false; + return x; } + copied = true; array x_copy(x.shape(), x.dtype(), nullptr, {}); copy_gpu(x, x_copy, CopyType::General, s); - return {x_copy, true}; + return x_copy; }; bool donate_x = inputs[0].is_donatable(); bool donate_g = inputs[3].is_donatable(); - auto [x, copied] = check_input(inputs[0]); + bool copied; + auto x = check_input(inputs[0], copied); donate_x |= copied; const array& w = inputs[1]; const array& b = inputs[2]; - auto [g, g_copied] = check_input(inputs[3]); + bool g_copied; + auto g = check_input(inputs[3], g_copied); donate_g |= g_copied; array& gx = outputs[0]; array& gw = outputs[1]; @@ -334,8 +338,10 @@ void LayerNormVJP::eval_gpu( // gradient accumulators. array gw_temp = (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; + bool g_in_gw = false; if (has_w) { if (!g_in_gx && donate_g) { + g_in_gw = true; gw_temp.copy_shared_buffer(g); } else { gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); @@ -343,41 +349,47 @@ void LayerNormVJP::eval_gpu( } } - // Finish with the gradient for b in case we had a b. - if (gb.ndim() == 1 && gb.size() == axis_size) { + // The gradient for b in case we had a b. + bool has_gb = (gb.ndim() == 1 && gb.size() == axis_size); + if (has_gb) { ReductionPlan plan( ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan); } + // Insert dependency if `g` was donated + if ((g_in_gx || g_in_gw) && has_gb) { + encoder.set_input_array(gb); + } encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_input_array(g); encoder.set_output_array(gx); encoder.set_output_array(gw_temp); - encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { - dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) { - dispatch_bool(has_w, [&](auto has_w_constant) { - constexpr int N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - auto kernel = cu::layer_norm_vjp< - DataType, - has_w_constant.value, - block_dim(), - N_READS>; - kernel<<>>( - x.data(), - w.data(), - g.data(), - gx.data(), - gw_temp.data(), - eps_, - axis_size, - w_stride); - }); - }); + dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) { + dispatch_bool(has_w, [&](auto has_w_constant) { + constexpr int N_READS = 4; + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::layer_norm_vjp< + DataType, + has_w_constant.value, + block_dim(), + N_READS>; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); }); }); diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu index 5d6bf437d..afc52826f 100644 --- a/mlx/backend/cuda/logsumexp.cu +++ b/mlx/backend/cuda/logsumexp.cu @@ -143,16 +143,18 @@ void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) { - constexpr int N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - auto kernel = 
cu::logsumexp; - kernel<<>>( - in.data(), out.data(), axis_size); - }); + dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) { + constexpr int N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::logsumexp; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + in.data(), + out.data(), + axis_size); }); }); } diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp index c32cecc03..e11c68b7d 100644 --- a/mlx/backend/cuda/matmul.cpp +++ b/mlx/backend/cuda/matmul.cpp @@ -42,7 +42,8 @@ class MatMul { int64_t ldb, int32_t batch_count, int64_t a_batch_stride, - int64_t b_batch_stride) { + int64_t b_batch_stride) + : handle_(device.lt_handle()) { heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; auto scale_type = dtype_to_cuda_type(dtype); @@ -147,7 +148,7 @@ class MatMul { if (heuristic_.state != CUBLAS_STATUS_SUCCESS) { int ret = 0; CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic( - encoder.device().lt_handle(), + handle_, matmul_desc_, a_desc_, b_desc_, @@ -172,25 +173,24 @@ class MatMul { workspace_ptr = workspace.data(); } - encoder.launch_kernel([&](cudaStream_t stream) { - CHECK_CUBLAS_ERROR(cublasLtMatmul( - encoder.device().lt_handle(), - matmul_desc_, - &alpha, - a, - a_desc_, - b, - b_desc_, - &beta, - c ? c : out, - c ? c_desc_ : out_desc_, - out, - out_desc_, - &heuristic_.algo, - workspace_ptr, - heuristic_.workspaceSize, - stream)); - }); + auto capture = encoder.capture_context(); + CHECK_CUBLAS_ERROR(cublasLtMatmul( + handle_, + matmul_desc_, + &alpha, + a, + a_desc_, + b, + b_desc_, + &beta, + c ? c : out, + c ? c_desc_ : out_desc_, + out, + out_desc_, + &heuristic_.algo, + workspace_ptr, + heuristic_.workspaceSize, + encoder.stream())); } private: @@ -259,6 +259,7 @@ class MatMul { return desc; } + cublasLtHandle_t handle_{nullptr}; cublasLtMatmulDesc_t matmul_desc_{nullptr}; cublasLtMatmulPreference_t pref_{nullptr}; cublasLtMatrixLayout_t a_desc_{nullptr}; @@ -273,7 +274,7 @@ class MatMul { namespace { std::tuple -check_transpose(std::vector& copies, const Stream& s, const array& arr) { +check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) { auto stx = arr.strides()[arr.ndim() - 2]; auto sty = arr.strides()[arr.ndim() - 1]; if (sty == 1 && stx == arr.shape(-1)) { @@ -283,7 +284,7 @@ check_transpose(std::vector& copies, const Stream& s, const array& arr) { } else { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); + enc.add_temporary(arr_copy); return std::make_tuple(false, arr.shape(-1), arr_copy); } } @@ -317,13 +318,8 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { // Keep a vector with copies to be cleared in the completed buffer to release // the arrays - std::vector copies; - auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre); - auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre); - - for (auto& temp : copies) { - encoder.add_temporary(temp); - } + auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); + auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions @@ -348,7 +344,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { // Invoke cublasLt cu::MatMul matmul( - encoder.device(), + cu::device(s.device), a.dtype(), a_transposed, M, @@ -373,6 
+369,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); + auto concurrent = encoder.concurrent_context(); for (size_t i = 0; i < nbatch; ++i) { matmul.run( encoder, @@ -405,14 +402,9 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Keep a vector with copies to be cleared in the completed buffer to release // the arrays - std::vector copies; - auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre); - auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre); - auto [c_transposed, ldc, c] = check_transpose(copies, s, c_pre); - - for (auto& temp : copies) { - encoder.add_temporary(temp); - } + auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); + auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); + auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre); ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions @@ -440,7 +432,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Invoke cublasLt cu::MatMul matmul( - encoder.device(), + cu::device(s.device), a.dtype(), a_transposed, M, @@ -478,6 +470,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); + auto concurrent = encoder.concurrent_context(); for (size_t i = 0; i < nbatch; ++i) { matmul.run( encoder, diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 715e5a232..18fa45a33 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -24,23 +24,21 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { if (out.size() == 0) { return; } - auto& s = stream(); - auto& encoder = cu::get_command_encoder(s); + auto& encoder = cu::get_command_encoder(stream()); encoder.set_output_array(out); - encoder.launch_kernel([&, this](cudaStream_t stream) { - dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { - using CTYPE = MLX_GET_TYPE(type_tag); - using OutType = cuda_type_t; - CTYPE step = - static_cast(start_ + step_) - static_cast(start_); - thrust::transform( - cu::thrust_policy(stream), - thrust::counting_iterator(0), - thrust::counting_iterator(out.data_size()), - thrust::device_pointer_cast(out.data()), - cu::Arange{ - static_cast(start_), static_cast(step)}); - }); + auto capture = encoder.capture_context(); + dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + using OutType = cuda_type_t; + CTYPE step = + static_cast(start_ + step_) - static_cast(start_); + thrust::transform( + cu::thrust_policy(encoder.stream()), + thrust::counting_iterator(0), + thrust::counting_iterator(out.data_size()), + thrust::device_pointer_cast(out.data()), + cu::Arange{ + static_cast(start_), static_cast(step)}); }); } diff --git a/mlx/backend/cuda/random.cu b/mlx/backend/cuda/random.cu index 0cb550d56..7221af356 100644 --- a/mlx/backend/cuda/random.cu +++ b/mlx/backend/cuda/random.cu @@ -156,34 +156,39 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(keys); 
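// A rough sketch, in plain CUDA, of what capture_context() presumably wraps
// for library calls such as cublasLtMatmul and thrust::transform above:
// begin stream capture, run the call, end capture, and splice the captured
// work into the graph being built as a child-graph node. Names and error
// handling are illustrative only.
#include <cuda_runtime.h>
#include <functional>

void capture_into_graph(
    cudaStream_t stream,
    cudaGraph_t parent,
    const std::function<void()>& launch_work) {
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  launch_work(); // e.g. a cuBLASLt or Thrust call enqueued on `stream`
  cudaGraph_t captured = nullptr;
  cudaStreamEndCapture(stream, &captured);
  cudaGraphNode_t node = nullptr;
  // The child node clones the captured graph, so it can be destroyed here.
  cudaGraphAddChildGraphNode(&node, parent, /*deps*/ nullptr, 0, captured);
  cudaGraphDestroy(captured);
}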
encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dim3 grid_dims{num_keys, half_size + odd}; - int64_t total = grid_dims.x * grid_dims.y; - int32_t threads_y = 1; - while ((total / threads_y) >= (1U << 31)) { - threads_y *= 2; - } - int32_t threads_x = cuda::ceil_div(total, threads_y); - auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1); - if (keys.flags().row_contiguous) { - cu::rbitsc<<>>( - keys.data(), - out.data(), - grid_dims, - odd, - bytes_per_key); - } else { - cu::rbits<<>>( - keys.data(), - out.data(), - grid_dims, - odd, - bytes_per_key, - keys.ndim(), - const_param(keys.shape()), - const_param(keys.strides())); - } - }); + dim3 grid_dims{num_keys, half_size + odd}; + int64_t total = grid_dims.x * grid_dims.y; + int32_t threads_y = 1; + while ((total / threads_y) >= (1U << 31)) { + threads_y *= 2; + } + int32_t threads_x = cuda::ceil_div(total, threads_y); + auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1); + auto& stream = encoder.stream(); + if (keys.flags().row_contiguous) { + encoder.add_kernel_node( + cu::rbitsc, + grid, + block, + keys.data(), + out.data(), + grid_dims, + odd, + bytes_per_key); + } else { + encoder.add_kernel_node( + cu::rbits, + grid, + block, + keys.data(), + out.data(), + grid_dims, + odd, + bytes_per_key, + keys.ndim(), + const_param(keys.shape()), + const_param(keys.strides())); + } } } // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index a6ccd5ae9..3419d61cb 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -110,19 +110,20 @@ void all_reduce( intermediate.set_data(allocator::malloc(intermediate.nbytes())); encoder.add_temporary(intermediate); encoder.set_output_array(intermediate); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(dt, [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; - auto kernel = cu::all_reduce; - kernel<<>>( - static_cast(indata), - intermediate.data(), - block_step, - insize); - }); + dispatch_all_types(dt, [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + auto kernel = cu::all_reduce; + encoder.add_kernel_node( + kernel, + blocks, + threads, + static_cast(indata), + intermediate.data(), + block_step, + insize); }); }); @@ -135,16 +136,20 @@ void all_reduce( } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(dt, [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; - auto kernel = cu::all_reduce; - kernel<<>>( - static_cast(indata), out.data(), block_step, insize); - }); + dispatch_all_types(dt, [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + auto kernel = cu::all_reduce; + encoder.add_kernel_node( + kernel, + blocks, + threads, + static_cast(indata), + out.data(), + block_step, + insize); }); }); } diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index 
78f6b93bc..910fa0379 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -214,26 +214,24 @@ void col_reduce_looped( encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + // Cub doesn't like const pointers for vectorized loads. (sigh) + T* indata = const_cast(in.data()); - // Cub doesn't like const pointers for vectorized loads. (sigh) - T* indata = const_cast(in.data()); - - constexpr int N_READS = 4; - constexpr int BM = 32; - constexpr int BN = 32; - dim3 grid = output_grid_for_col_reduce(out, args, BN); - int blocks = BM * BN / N_READS; - auto kernel = - cu::col_reduce_looped; - kernel<<>>(indata, out.data(), args); - }); + constexpr int N_READS = 4; + constexpr int BM = 32; + constexpr int BN = 32; + dim3 grid = output_grid_for_col_reduce(out, args, BN); + int blocks = BM * BN / N_READS; + auto kernel = + cu::col_reduce_looped; + encoder.add_kernel_node( + kernel, grid, blocks, indata, out.data(), args); }); }); }); diff --git a/mlx/backend/cuda/reduce/init_reduce.cu b/mlx/backend/cuda/reduce/init_reduce.cu index 296a4e611..649d80190 100644 --- a/mlx/backend/cuda/reduce/init_reduce.cu +++ b/mlx/backend/cuda/reduce/init_reduce.cu @@ -32,18 +32,16 @@ void init_reduce( } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; - auto kernel = cu::init_reduce; - dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); - dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1); - grid.x = (grid.x + 1023) / 1024; - kernel<<>>(out.data(), out.size()); - }); + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + auto kernel = cu::init_reduce; + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1); + grid.x = (grid.x + 1023) / 1024; + encoder.add_kernel_node(kernel, grid, block, out.data(), out.size()); }); }); } diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index deb4a2f91..e57f18668 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -245,34 +245,32 @@ void row_reduce_simple( // 2 passes. Something like 32 * out.size() and then do a warp reduce. 
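// A sketch of what each add_kernel_node call above reduces to in the plain
// runtime API (the encoder signature is taken from this diff, everything else
// is illustrative): record the kernel and its launch configuration as a graph
// node instead of launching with <<<grid, block, 0, stream>>>.
#include <cuda_runtime.h>

template <typename... Args>
cudaGraphNode_t add_kernel_node_sketch(
    cudaGraph_t graph,
    void* kernel, // address of a __global__ function
    dim3 grid,
    dim3 block,
    Args*... args) { // pointers to the (non-const) argument values; at least one
  void* kernel_args[] = {static_cast<void*>(args)...};
  cudaKernelNodeParams params = {};
  params.func = kernel;
  params.gridDim = grid;
  params.blockDim = block;
  params.sharedMemBytes = 0;
  params.kernelParams = kernel_args;
  params.extra = nullptr;
  cudaGraphNode_t node = nullptr;
  cudaGraphAddKernelNode(&node, graph, /*deps*/ nullptr, 0, &params);
  return node;
}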
encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; - // Cub doesn't like const pointers for vectorized loads. (sigh) - T* indata = const_cast(in.data()); + // Cub doesn't like const pointers for vectorized loads. (sigh) + T* indata = const_cast(in.data()); - // Calculate the grid and block dims - size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS; - dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); - int threads = std::min(1024UL, reductions); - threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - dim3 block(threads, 1, 1); + // Calculate the grid and block dims + size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS; + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + int threads = std::min(1024UL, reductions); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + dim3 block(threads, 1, 1); - // Pick the kernel - auto kernel = cu::row_reduce_simple; - if (grid.x >= 1024) { - grid.x = (grid.x + 1) / 2; - kernel = cu::row_reduce_simple; - } + // Pick the kernel + auto kernel = cu::row_reduce_simple; + if (grid.x >= 1024) { + grid.x = (grid.x + 1) / 2; + kernel = cu::row_reduce_simple; + } - // Launch - kernel<<>>( - indata, out.data(), out.size(), plan.shape.back()); - }); + int size = plan.shape.back(); + encoder.add_kernel_node( + kernel, grid, block, indata, out.data(), out.size(), size); }); }); } @@ -293,43 +291,39 @@ void row_reduce_looped( encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto type_tag) { - dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { - using OP = MLX_GET_TYPE(reduce_type_tag); - using T = cuda_type_t; - using U = typename cu::ReduceResult::type; + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + // Cub doesn't like const pointers for vectorized loads. (sigh) + T* indata = const_cast(in.data()); - // Cub doesn't like const pointers for vectorized loads. 
(sigh) - T* indata = const_cast(in.data()); + // Calculate the grid and block dims + args.sort_access_pattern(in, axes); + dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); + size_t reductions = (args.row_size + N_READS - 1) / N_READS; + int threads = std::min(1024UL, reductions); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + dim3 block(threads, 1, 1); - // Calculate the grid and block dims - args.sort_access_pattern(in, axes); - dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); - size_t reductions = (args.row_size + N_READS - 1) / N_READS; - int threads = std::min(1024UL, reductions); - threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - dim3 block(threads, 1, 1); - - // Pick the kernel - auto kernel = cu::row_reduce_looped; - dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { - dispatch_block_dim(threads, [&](auto threads_constant) { - kernel = cu::row_reduce_looped< - T, - U, - OP, - reduce_ndim.value, - threads_constant.value, - N_READS>; - block.x = threads_constant.value; - }); + // Pick the kernel + auto kernel = cu::row_reduce_looped; + dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { + dispatch_block_dim(threads, [&](auto threads_constant) { + kernel = cu::row_reduce_looped< + T, + U, + OP, + reduce_ndim.value, + threads_constant.value, + N_READS>; + block.x = threads_constant.value; }); - - // Launch - kernel<<>>( - indata, out.data(), out.size(), args); }); + + encoder.add_kernel_node( + kernel, grid, block, indata, out.data(), out.size(), args); }); }); } diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu index fc8f4f490..5ee1d3386 100644 --- a/mlx/backend/cuda/rms_norm.cu +++ b/mlx/backend/cuda/rms_norm.cu @@ -224,21 +224,21 @@ void RMSNorm::eval_gpu( encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) { - constexpr uint32_t N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - auto kernel = cu::rms_norm; - kernel<<>>( - x.data(), - w.data(), - out.data(), - eps_, - axis_size, - w_stride); - }); + dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) { + constexpr uint32_t N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::rms_norm; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + x.data(), + w.data(), + out.data(), + eps_, + axis_size, + w_stride); }); }); } @@ -253,20 +253,24 @@ void RMSNormVJP::eval_gpu( // Ensure row contiguity. We could relax this step by checking that the array // is contiguous (no broadcasts or holes) and that the input strides are the // same as the cotangent strides but for now this is simpler. 
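// A generic illustration of the dispatch_block_dim idiom used throughout
// these hunks (helper name hypothetical): map a runtime thread count to a
// compile-time block size so kernels like rms_norm<DataType, block_dim(),
// N_READS> can be instantiated with a constexpr block dimension.
#include <type_traits>

template <typename F>
void dispatch_pow2_block_dim(int needed_threads, F&& f) {
  if (needed_threads <= 128) {
    f(std::integral_constant<int, 128>{});
  } else if (needed_threads <= 256) {
    f(std::integral_constant<int, 256>{});
  } else if (needed_threads <= 512) {
    f(std::integral_constant<int, 512>{});
  } else {
    f(std::integral_constant<int, 1024>{});
  }
}
// Usage: dispatch_pow2_block_dim(n, [&](auto bd) { /* bd() is constexpr */ });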
- auto check_input = [&s](const array& x) -> std::pair { + auto check_input = [&s](const array& x, bool& copied) { if (x.flags().row_contiguous) { - return {x, false}; + copied = false; + return x; } + copied = true; array x_copy(x.shape(), x.dtype(), nullptr, {}); copy_gpu(x, x_copy, CopyType::General, s); - return {x_copy, true}; + return x_copy; }; bool donate_x = inputs[0].is_donatable(); bool donate_g = inputs[2].is_donatable(); - auto [x, copied] = check_input(inputs[0]); + bool copied; + auto x = check_input(inputs[0], copied); donate_x |= copied; const array& w = inputs[1]; - auto [g, g_copied] = check_input(inputs[2]); + bool g_copied; + auto g = check_input(inputs[2], g_copied); donate_g |= g_copied; array& gx = outputs[0]; array& gw = outputs[1]; @@ -310,30 +314,31 @@ void RMSNormVJP::eval_gpu( encoder.set_input_array(g); encoder.set_output_array(gx); encoder.set_output_array(gw_temp); - encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { - dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) { - dispatch_bool(has_w, [&](auto has_w_constant) { - constexpr int N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - constexpr int N_READS = 4; - auto kernel = cu::rms_norm_vjp< - DataType, - has_w_constant.value, - block_dim(), - N_READS>; - kernel<<>>( - x.data(), - w.data(), - g.data(), - gx.data(), - gw_temp.data(), - eps_, - axis_size, - w_stride); - }); - }); + dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) { + dispatch_bool(has_w, [&](auto has_w_constant) { + constexpr int N_READS = 4; + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + constexpr int N_READS = 4; + auto kernel = cu::rms_norm_vjp< + DataType, + has_w_constant.value, + block_dim(), + N_READS>; + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); }); }); diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index bb9618fc4..517cddfe0 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -308,76 +308,89 @@ void RoPE::eval_gpu( auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(donated ? out : in); encoder.set_input_array(offset); + if (with_freqs) { + encoder.set_input_array(inputs[2]); + } encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { - dispatch_bool(traditional_, [&](auto traditional) { - dispatch_bool(forward_, [&](auto forward) { - using DataType = cuda_type_t; - if (single && !with_freqs) { - auto kernel = - cu::rope_single; - uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); - auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); - kernel<<>>( - (donated ? out : in).data(), - out.data(), - offset.data(), - scale_, - std::log2(base_), - mat_size, - dims); - } else if (single) { - auto kernel = cu:: - rope_single_freqs; - uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); - auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); - kernel<<>>( - (donated ? 
out : in).data(), - out.data(), - offset.data(), - inputs[2].data(), - scale_, - mat_size, - dims, - inputs[2].strides(0)); - } else if (with_freqs) { - auto kernel = - cu::rope_freqs; - uint3 dims = - make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); - dims.z = (dims.z + 3) / 4; - auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); - kernel<<>>( - (donated ? out : in).data(), - out.data(), - offset.data(), - inputs[2].data(), - scale_, - std::log2(base_), - strides, - out_strides, - in.size() / mat_size, - dims, - inputs[2].strides(0)); - } else { - auto kernel = cu::rope; - uint3 dims = - make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); - dims.z = (dims.z + 3) / 4; - auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); - kernel<<>>( - (donated ? out : in).data(), - out.data(), - offset.data(), - scale_, - std::log2(base_), - strides, - out_strides, - in.size() / mat_size, - dims); - } - }); + dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { + dispatch_bool(traditional_, [&](auto traditional) { + dispatch_bool(forward_, [&](auto forward) { + using DataType = cuda_type_t; + if (single && !with_freqs) { + auto kernel = + cu::rope_single; + uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); + encoder.add_kernel_node( + kernel, + grid, + block, + (donated ? out : in).data(), + out.data(), + offset.data(), + scale_, + std::log2(base_), + mat_size, + dims); + } else if (single) { + auto kernel = + cu::rope_single_freqs; + uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); + encoder.add_kernel_node( + kernel, + grid, + block, + (donated ? out : in).data(), + out.data(), + offset.data(), + inputs[2].data(), + scale_, + mat_size, + dims, + inputs[2].strides(0)); + } else if (with_freqs) { + auto kernel = + cu::rope_freqs; + uint3 dims = + make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); + dims.z = (dims.z + 3) / 4; + auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); + encoder.add_kernel_node( + kernel, + grid, + block, + (donated ? out : in).data(), + out.data(), + offset.data(), + inputs[2].data(), + scale_, + std::log2(base_), + strides, + out_strides, + in.size() / mat_size, + dims, + inputs[2].strides(0)); + } else { + auto kernel = cu::rope; + uint3 dims = + make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); + dims.z = (dims.z + 3) / 4; + auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); + encoder.add_kernel_node( + kernel, + grid, + block, + (donated ? 
out : in).data(), + out.data(), + offset.data(), + scale_, + std::log2(base_), + strides, + out_strides, + in.size() / mat_size, + dims); + } }); }); }); diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu index af9ddf214..fd807bd8d 100644 --- a/mlx/backend/cuda/softmax.cu +++ b/mlx/backend/cuda/softmax.cu @@ -141,19 +141,21 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { - constexpr int N_READS = 4; - dispatch_block_dim( - cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; - auto kernel = cu::softmax; - if (precise) { - kernel = cu::softmax; - } - kernel<<>>( - in.data(), out.data(), axis_size); - }); + dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { + constexpr int N_READS = 4; + dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::softmax; + if (precise) { + kernel = cu::softmax; + } + encoder.add_kernel_node( + kernel, + n_rows, + block_dim(), + in.data(), + out.data(), + axis_size); }); }); } diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 2c5599bed..379c55706 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -50,32 +50,6 @@ array swapaxes_in_eval(const array& in, int axis1, int axis2) { return out; } -template -void segmented_sort_pairs(cu::CommandEncoder& encoder, Args&&... args) { - // Allocate temporary storage. - size_t size; - CHECK_CUDA_ERROR( - cub::DeviceSegmentedSort::StableSortPairs(nullptr, size, args...)); - array temp(allocator::malloc(size), {static_cast(size)}, uint8); - encoder.add_temporary(temp); - // Run op. - CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs( - temp.data(), size, args...)); -} - -template -void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) { - // Allocate temporary storage. - size_t size; - CHECK_CUDA_ERROR( - cub::DeviceSegmentedSort::StableSortKeys(nullptr, size, args...)); - array temp(allocator::malloc(size), {static_cast(size)}, uint8); - encoder.add_temporary(temp); - // Run op. - CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys( - temp.data(), size, args...)); -} - struct OffsetTransform { int nsort; @@ -113,57 +87,94 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto type_tag) { - using CTYPE = MLX_GET_TYPE(type_tag); - if constexpr (!std::is_same_v) { - using Type = cuda_type_t; - auto offsets = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), OffsetTransform{nsort}); - if (argsort) { - // Indices in the sorted dimension. 
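// Note on the new set_input_array(inputs[2]) for the RoPE frequencies above:
// once work is recorded as graph nodes, ordering presumably comes from the
// declared inputs and outputs rather than from stream submission order, so
// every array a node touches has to be declared. In raw CUDA terms an edge
// between two nodes is just:
#include <cuda_runtime.h>

void order_after(cudaGraph_t graph, cudaGraphNode_t producer, cudaGraphNode_t consumer) {
  // consumer may not begin until producer has completed.
  cudaGraphAddDependencies(graph, &producer, &consumer, 1);
}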
- array indices( - allocator::malloc(out.nbytes()), in.shape(), out.dtype()); - encoder.add_temporary(indices); - thrust::transform( - cu::thrust_policy(stream), - thrust::counting_iterator(0), - thrust::counting_iterator(indices.data_size()), - thrust::device_pointer_cast(indices.data()), - ModOp{static_cast(nsort)}); + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + auto& stream = encoder.stream(); + if constexpr (!std::is_same_v) { + using Type = cuda_type_t; + auto offsets = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), OffsetTransform{nsort}); + if (argsort) { + // Indices in the sorted dimension. + array indices(allocator::malloc(out.nbytes()), in.shape(), out.dtype()); + encoder.add_temporary(indices); - // In argsort though we don't need the result of sorted values, the - // API requires us to provide an array to store it. - array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype()); - encoder.add_temporary(discard); + // In argsort though we don't need the result of sorted values, the + // API requires us to provide an array to store it. + array discard(allocator::malloc(in.nbytes()), in.shape(), in.dtype()); + encoder.add_temporary(discard); - segmented_sort_pairs( - encoder, - in.data(), - discard.data(), - indices.data(), - out.data(), - in.data_size(), - in.data_size() / nsort, - offsets, - offsets + 1, - stream); - } else { - segmented_sort( - encoder, - in.data(), - out.data(), - in.data_size(), - in.data_size() / nsort, - offsets, - offsets + 1, - stream); - } + size_t size; + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs( + nullptr, + size, + in.data(), + discard.data(), + indices.data(), + out.data(), + in.data_size(), + in.data_size() / nsort, + offsets, + offsets + 1, + stream)); + + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + + // Start capturing after allocations + auto capture = encoder.capture_context(); + thrust::transform( + cu::thrust_policy(stream), + thrust::counting_iterator(0), + thrust::counting_iterator(indices.data_size()), + thrust::device_pointer_cast(indices.data()), + ModOp{static_cast(nsort)}); + + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs( + temp.data(), + size, + in.data(), + discard.data(), + indices.data(), + out.data(), + in.data_size(), + in.data_size() / nsort, + offsets, + offsets + 1, + stream)); } else { - throw std::runtime_error( - "CUDA backend does not support sorting complex numbers"); + size_t size; + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys( + nullptr, + size, + in.data(), + out.data(), + in.data_size(), + in.data_size() / nsort, + offsets, + offsets + 1, + stream)); + + array temp(allocator::malloc(size), {static_cast(size)}, uint8); + encoder.add_temporary(temp); + + // Start capturing after allocations + auto capture = encoder.capture_context(); + CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys( + temp.data(), + size, + in.data(), + out.data(), + in.data_size(), + in.data_size() / nsort, + offsets, + offsets + 1, + stream)); } - }); + } else { + throw std::runtime_error( + "CUDA backend does not support sorting complex numbers"); + } }); if (!is_segmented_sort) { diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index 1d6535100..aa6523f27 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -91,73 +91,80 @@ void ternary_op_gpu_inplace( encoder.set_input_array(b); encoder.set_input_array(c); 
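// A standalone sketch of the two-phase CUB pattern inlined in the sort
// rewrite above: the first call with a null temporary pointer only reports
// the scratch size, the second call does the work. In the diff the scratch
// allocation happens before capture_context(), presumably because allocations
// are not safe to record into the captured graph. All names below are
// illustrative.
#include <cub/cub.cuh>
#include <cuda_runtime.h>

void segmented_sort_keys_sketch(
    const float* keys_in,
    float* keys_out,
    int n_items,
    int n_segments,
    const int* seg_offsets, // n_segments + 1 offsets
    cudaStream_t stream) {
  size_t temp_bytes = 0;
  cub::DeviceSegmentedSort::StableSortKeys(
      nullptr, temp_bytes, keys_in, keys_out, n_items, n_segments,
      seg_offsets, seg_offsets + 1, stream);
  void* temp = nullptr;
  cudaMalloc(&temp, temp_bytes); // allocate before any capture begins
  cub::DeviceSegmentedSort::StableSortKeys(
      temp, temp_bytes, keys_in, keys_out, n_items, n_segments,
      seg_offsets, seg_offsets + 1, stream);
  cudaStreamSynchronize(stream);
  cudaFree(temp);
}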
encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(out.dtype(), [&](auto type_tag) { - using DType = cuda_type_t; + dispatch_all_types(out.dtype(), [&](auto type_tag) { + using DType = cuda_type_t; - auto topt = get_ternary_op_type(a, b, c); - if (topt == TernaryOpType::General) { - dispatch_bool( - a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || - c.data_size() > INT32_MAX || out.data_size() > INT32_MAX, - [&](auto large) { - using IdxT = std::conditional_t; - Shape shape; - std::vector strides; - std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out); - auto& a_strides = strides[0]; - auto& b_strides = strides[1]; - auto& c_strides = strides[2]; - int ndim = shape.size(); - if (ndim <= 3) { - dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto kernel = - cu::ternary_g_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large()); - kernel<<>>( - a.data(), - b.data(), - c.data(), - out.data(), - out.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - const_param(c_strides)); - }); - } else { - auto kernel = cu::ternary_g; + auto topt = get_ternary_op_type(a, b, c); + if (topt == TernaryOpType::General) { + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + c.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + auto& c_strides = strides[2]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = + cu::ternary_g_nd; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large()); - kernel<<>>( + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, a.data(), b.data(), c.data(), out.data(), - out.data_size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - const_param(c_strides), - ndim); - } - }); - } else { - dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { - using IdxT = std::conditional_t; - auto kernel = cu::ternary_v; - auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large()); - kernel<<>>( - a.data(), - b.data(), - c.data(), - out.data(), - out.data_size()); - }); - } - }); + out.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + const_param(c_strides)); + }); + } else { + auto kernel = cu::ternary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + c.data(), + out.data(), + out.data_size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + const_param(c_strides), + ndim); + } + }); + } else { + dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; + auto kernel = cu::ternary_v; + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), large()); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + a.data(), + b.data(), + c.data(), + out.data(), + out.data_size()); + }); + } }); } diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index 74251d1f6..3f1a62d24 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -9,14 +9,38 @@ #include 
"mlx/dtype_utils.h" #include "mlx/primitives.h" +#include #include -#include -#include namespace mlx::core { namespace cu { +namespace cg = cooperative_groups; + +template +__global__ void unary_v(const In* in, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + out[index] = Op{}(in[index]); + } +} + +template +__global__ void unary_g( + const In* in, + Out* out, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides strides, + int ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto idx = elem_to_loc_4d(index, shape.data(), strides.data(), ndim); + out[index] = Op{}(in[idx]); + } +} + template constexpr bool supports_unary_op() { if (std::is_same_v || std::is_same_v || @@ -71,38 +95,61 @@ void unary_op_gpu_inplace( if (in.size() == 0) { return; } + bool contig = in.flags().contiguous; + bool large; + if (!contig) { + large = in.data_size() > INT32_MAX || out.size() > INT32_MAX; + } else { + large = in.data_size() > UINT32_MAX; + } auto& encoder = cu::get_command_encoder(s); encoder.set_input_array(in); encoder.set_output_array(out); - encoder.launch_kernel([&](cudaStream_t stream) { - dispatch_all_types(in.dtype(), [&](auto in_type_tag) { - dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - using CTYPE_IN = MLX_GET_TYPE(in_type_tag); - using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); - if constexpr (cu::supports_unary_op()) { + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); + if constexpr (cu::supports_unary_op()) { + dispatch_bool(large, [&](auto large) { + using IdxT = std::conditional_t; using InType = cuda_type_t; using OutType = cuda_type_t; - auto policy = cu::thrust_policy(stream); - auto in_ptr = thrust::device_pointer_cast(in.data()); - auto out_ptr = thrust::device_pointer_cast(out.data()); - if (in.flags().contiguous) { - thrust::transform( - policy, in_ptr, in_ptr + in.data_size(), out_ptr, Op()); + using IdxT = std::conditional_t; + if (contig) { + auto kernel = cu::unary_v; + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), large); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in.data(), + out.data(), + out.data_size()); } else { auto [shape, strides] = collapse_contiguous_dims(in); - auto [in_begin, in_end] = cu::make_general_iterators( - in_ptr, in.size(), shape, strides); - thrust::transform(policy, in_begin, in_end, out_ptr, Op()); + auto kernel = cu::unary_g; + auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + in.data(), + out.data(), + out.data_size(), + const_param(shape), + const_param(strides), + shape.size()); } - } else { - throw std::runtime_error(fmt::format( - "Can not do unary op {} on input of {} with output of {}.", - op, - dtype_to_string(in.dtype()), - dtype_to_string(out.dtype()))); - } - }); + }); + } else { + throw std::runtime_error(fmt::format( + "Can not do unary op {} on input of {} with output of {}.", + op, + dtype_to_string(in.dtype()), + dtype_to_string(out.dtype()))); + } }); }); } diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 35731f6eb..cc05428a8 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -24,6 +24,14 @@ void check_cuda_error(const char* name, cudaError_t err) 
{ } } +void check_cuda_error(const char* name, CUresult err) { + if (err != CUDA_SUCCESS) { + const char* err_str = "Unknown error"; + cuGetErrorString(err, &err_str); + throw std::runtime_error(fmt::format("{} failed: {}", name, err_str)); + } +} + const char* dtype_to_cuda_type(const Dtype& dtype) { switch (dtype) { case bool_: diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h index 6d98cdcd5..bfb02c5b6 100644 --- a/mlx/backend/cuda/utils.h +++ b/mlx/backend/cuda/utils.h @@ -4,6 +4,7 @@ #pragma once +#include #include namespace mlx::core { @@ -33,6 +34,7 @@ class CudaStream { // Throw exception if the cuda API does not succeed. void check_cuda_error(const char* name, cudaError_t err); +void check_cuda_error(const char* name, CUresult err); // The macro version that prints the command that failed. #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 144f9a880..ff3208e1e 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -688,7 +688,7 @@ array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) { perm = expand_dims(perm, -1, s); take_axis -= 1; } - auto pb = take_along_axis(b, perm, take_axis); + auto pb = take_along_axis(b, perm, take_axis, s); auto y = solve_triangular(luf[1], pb, /* upper = */ false, s); return solve_triangular(luf[2], y, /* upper = */ true, s); } diff --git a/python/tests/test_load.py b/python/tests/test_load.py index 35f7016c5..840d3b471 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -391,9 +391,11 @@ class TestLoad(mlx_tests.MLXTestCase): scale = mx.array(2.0) y = mx.load(save_file) mx.eval(y) + mx.synchronize() load_only = mx.get_peak_memory() y = mx.load(save_file) * scale mx.eval(y) + mx.synchronize() load_with_binary = mx.get_peak_memory() self.assertEqual(load_only, load_with_binary)
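// Closing sketch, plain CUDA and not MLX internals: the node-building and
// capture pieces above fit into the usual graph lifecycle, instantiate once
// and launch on a stream, which runs asynchronously; that is presumably also
// why the Python test now calls mx.synchronize() before reading peak memory.
// CUDA 12 signatures assumed.
#include <cuda_runtime.h>

void build_and_run(cudaStream_t stream) {
  cudaGraph_t graph = nullptr;
  cudaGraphCreate(&graph, 0);
  // ... add kernel nodes / captured child graphs here, as sketched earlier ...
  cudaGraphExec_t exec = nullptr;
  cudaGraphInstantiate(&exec, graph, 0); // CUDA 12 flags overload
  cudaGraphLaunch(exec, stream);
  cudaStreamSynchronize(stream);
  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
}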