diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index d96bb8812..ad979a13f 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -8,6 +8,7 @@ target_sources(
   PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu
index e8e8a8988..9c437cde9 100644
--- a/mlx/backend/cuda/binary.cu
+++ b/mlx/backend/cuda/binary.cu
@@ -125,13 +125,12 @@ constexpr bool supports_binary_op() {
 template <typename Op>
 void binary_op_gpu_inplace(
     const std::vector<array>& inputs,
-    std::vector<array>& outputs,
+    array& out,
     std::string_view op,
     const Stream& s) {
   assert(inputs.size() > 1);
   const auto& a = inputs[0];
   const auto& b = inputs[1];
-  auto& out = outputs[0];
   if (out.size() == 0) {
     return;
   }
@@ -146,7 +145,6 @@ void binary_op_gpu_inplace(
         if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
           using InType = cuda_type_t<CTYPE_IN>;
           using OutType = cuda_type_t<CTYPE_OUT>;
-
           auto bopt = get_binary_op_type(a, b);
           if (bopt == BinaryOpType::General) {
             auto [shape, strides] = collapse_contiguous_dims(a, b, out);
@@ -219,20 +217,6 @@ void binary_op_gpu(
   });
 }
 
-template <typename Op>
-void binary_op_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs,
-    std::string_view op,
-    const Stream& s) {
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-  auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, outputs[0], bopt);
-  set_binary_op_output_data(a, b, outputs[1], bopt);
-  binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
-}
-
 template <typename Op>
 void binary_op_gpu(
     const std::vector<array>& inputs,
@@ -243,8 +227,7 @@ void binary_op_gpu(
   auto& b = inputs[1];
   auto bopt = get_binary_op_type(a, b);
   set_binary_op_output_data(a, b, out, bopt);
-  std::vector<array> outputs{out};
-  binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
+  binary_op_gpu_inplace<Op>(inputs, out, op, s);
 }
 
 #define BINARY_GPU(func)                                                  \
@@ -254,14 +237,6 @@ void binary_op_gpu(
     binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
   }
 
-#define BINARY_GPU_MULTI(func)                                               \
-  void func::eval_gpu(                                                       \
-      const std::vector<array>& inputs, std::vector<array>& outputs) {       \
-    nvtx3::scoped_range r(#func "::eval_gpu");                               \
-    auto& s = outputs[0].primitive().stream();                               \
-    binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s); \
-  }
-
 BINARY_GPU(Add)
 BINARY_GPU(ArcTan2)
 BINARY_GPU(Divide)
diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu
new file mode 100644
index 000000000..3047e39f0
--- /dev/null
+++ b/mlx/backend/cuda/binary_two.cu
@@ -0,0 +1,248 @@
+// Copyright © 2025 Apple Inc.
+ +#include "mlx/backend/common/binary.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/binary_ops.cuh" +#include "mlx/backend/cuda/device/cucomplex_math.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void +binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto out = Op{}(a[0], b[0]); + out_a[0] = out[0]; + out_b[0] = out[1]; + } +} + +template +__global__ void +binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto out = Op{}(a[0], b[index]); + out_a[index] = out[0]; + out_b[index] = out[1]; + } +} + +template +__global__ void +binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto out = Op{}(a[index], b[0]); + out_a[index] = out[0]; + out_b[index] = out[1]; + } +} + +template +__global__ void +binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto out = Op{}(a[index], b[index]); + out_a[index] = out[0]; + out_b[index] = out[1]; + } +} + +template +__global__ void binary_g_nd( + const In* a, + const In* b, + Out* out_a, + Out* out_b, + IdxT size, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array a_strides, + const __grid_constant__ cuda::std::array b_strides) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx] = elem_to_loc_nd( + index, shape.data(), a_strides.data(), b_strides.data()); + auto out = Op{}(a[a_idx], b[b_idx]); + out_a[index] = out[0]; + out_b[index] = out[1]; + } +} + +template +__global__ void binary_g( + const In* a, + const In* b, + Out* out_a, + Out* out_b, + IdxT size, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides a_strides, + const __grid_constant__ Strides b_strides, + int ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index < size) { + auto [a_idx, b_idx] = elem_to_loc_4d( + index, shape.data(), a_strides.data(), b_strides.data(), ndim); + auto out = Op{}(a[a_idx], b[b_idx]); + out_a[index] = out[0]; + out_b[index] = out[1]; + } +} + +template +constexpr bool supports_binary_op() { + if (std::is_same_v) { + return std::is_same_v && + (std::is_integral_v || is_floating_v); + } + return false; +} + +} // namespace cu + +template +void binary_op_gpu_inplace( + const std::vector& inputs, + std::vector& outputs, + std::string_view op, + const Stream& s) { + assert(inputs.size() > 1); + const auto& a = inputs[0]; + const auto& b = inputs[1]; + auto& out_a = outputs[0]; + auto& out_b = outputs[1]; + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out_a, bopt); + set_binary_op_output_data(a, b, out_b, bopt); + + if (out_a.size() == 0) { + return; + } + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out_a); + encoder.set_output_array(out_b); + encoder.launch_kernel([&](cudaStream_t stream) { + MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, { + MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, { + if constexpr (cu::supports_binary_op()) { + using InType = cuda_type_t; + using 
+          using OutType = cuda_type_t<CTYPE_OUT>;
+
+          auto bopt = get_binary_op_type(a, b);
+          if (bopt == BinaryOpType::General) {
+            auto [shape, strides] = collapse_contiguous_dims(a, b, out_a);
+            auto& a_strides = strides[0];
+            auto& b_strides = strides[1];
+            bool large = a.data_size() > INT32_MAX ||
+                b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX;
+            MLX_SWITCH_BOOL(large, LARGE, {
+              using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
+              int ndim = shape.size();
+              if (ndim <= 3) {
+                MLX_SWITCH_1_2_3(ndim, NDIM, {
+                  auto kernel =
+                      &cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
+                  auto [num_blocks, block_dims] =
+                      get_launch_args(kernel, out_a, large);
+                  kernel<<<num_blocks, block_dims, 0, stream>>>(
+                      a.data<InType>(),
+                      b.data<InType>(),
+                      out_a.data<OutType>(),
+                      out_b.data<OutType>(),
+                      out_a.size(),
+                      const_param<NDIM>(shape),
+                      const_param<NDIM>(a_strides),
+                      const_param<NDIM>(b_strides));
+                });
+              } else {
+                auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
+                auto [num_blocks, block_dims] =
+                    get_launch_args(kernel, out_a, large);
+                kernel<<<num_blocks, block_dims, 0, stream>>>(
+                    a.data<InType>(),
+                    b.data<InType>(),
+                    out_a.data<OutType>(),
+                    out_b.data<OutType>(),
+                    out_a.size(),
+                    const_param(shape),
+                    const_param(a_strides),
+                    const_param(b_strides),
+                    ndim);
+              }
+            });
+          } else {
+            MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, {
+              using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
+              auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
+              if (bopt == BinaryOpType::ScalarVector) {
+                kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
+              } else if (bopt == BinaryOpType::VectorScalar) {
+                kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
+              } else if (bopt == BinaryOpType::VectorVector) {
+                kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
+              }
+              auto [num_blocks, block_dims] = get_launch_args(
+                  kernel,
+                  out_a.data_size(),
+                  out_a.shape(),
+                  out_a.strides(),
+                  LARGE);
+              kernel<<<num_blocks, block_dims, 0, stream>>>(
+                  a.data<InType>(),
+                  b.data<InType>(),
+                  out_a.data<OutType>(),
+                  out_b.data<OutType>(),
+                  out_a.data_size());
+            });
+          }
+        } else {
+          throw std::runtime_error(fmt::format(
+              "Can not do binary op {} on inputs of {} with result of {}.",
+              op,
+              dtype_to_string(a.dtype()),
+              dtype_to_string(out_a.dtype())));
+        }
+      });
+    });
+  });
+}
+
+template <typename Op>
+void binary_op_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs,
+    std::string_view op,
+    const Stream& s) {
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, outputs[0], bopt);
+  set_binary_op_output_data(a, b, outputs[1], bopt);
+  binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
+}
+
+void DivMod::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  nvtx3::scoped_range r("DivMod::eval_gpu");
+  auto& s = outputs[0].primitive().stream();
+  binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh
index ca5ac35e6..dc4f8e7bb 100644
--- a/mlx/backend/cuda/device/binary_ops.cuh
+++ b/mlx/backend/cuda/device/binary_ops.cuh
@@ -22,7 +22,7 @@ struct FloorDivide {
     if constexpr (cuda::std::is_integral_v<T>) {
       return x / y;
     } else {
-      return trunc(x / y);
+      return truncf(x / y);
     }
   }
 };
@@ -132,7 +132,7 @@ struct LogAddExp {
           cuda::std::numeric_limits<float>::quiet_NaN(),
           cuda::std::numeric_limits<float>::quiet_NaN()};
     }
-    constexpr float inf = cuda::std::numeric_limits<float>::infinity();
+    float inf = cuda::std::numeric_limits<float>::infinity();
     auto maxval = x > y ? x : y;
     auto minval = x < y ? x : y;
     if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
diff --git a/mlx/backend/cuda/device/config.h b/mlx/backend/cuda/device/config.h
index 0933cc8b5..5a3402905 100644
--- a/mlx/backend/cuda/device/config.h
+++ b/mlx/backend/cuda/device/config.h
@@ -5,7 +5,7 @@
 #pragma once
 
 // The maximum dimensions of shape/strides passed as kernel parameters.
-#define MAX_NDIM 8
+#define MAX_NDIM 10
 
 // All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
 // warpSize variable exists, using it would prevent compile-time optimizations.
diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu
index c2362bea2..e32befc9c 100644
--- a/mlx/backend/cuda/primitives.cu
+++ b/mlx/backend/cuda/primitives.cu
@@ -71,10 +71,8 @@ bool fast::ScaledDotProductAttention::use_fallback(
     throw std::runtime_error(#func " has no CUDA implementation."); \
   }
 
-NO_GPU(ArgPartition)
 NO_GPU(BlockMaskedMM)
 NO_GPU(Convolution)
-NO_GPU_MULTI(DivMod)
 NO_GPU(DynamicSlice)
 NO_GPU(DynamicSliceUpdate)
 NO_GPU(FFT)
@@ -83,7 +81,6 @@ NO_GPU(GatherQMM)
 NO_GPU(Hadamard)
 NO_GPU(Load)
 NO_GPU_MULTI(LUF)
-NO_GPU(Partition)
 NO_GPU_MULTI(QRF)
 NO_GPU(QuantizedMatmul)
 NO_GPU(Scan)
diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu
index e1c2e8530..154ca5f32 100644
--- a/mlx/backend/cuda/sort.cu
+++ b/mlx/backend/cuda/sort.cu
@@ -86,7 +86,6 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
     axis += in.ndim();
   }
   int nsort = in.shape(axis);
-  int nsegments = in.data_size() / nsort;
   int last_dim = in.ndim() - 1;
 
   // If we are not sorting the innermost dimension of a contiguous array,
@@ -100,7 +99,11 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
     out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
     encoder.add_temporary(out);
   } else {
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(
+        allocator::malloc(in.data_size() * out.itemsize()),
+        in.data_size(),
+        in.strides(),
+        in.flags());
   }
 
   encoder.launch_kernel([&](cudaStream_t stream) {
@@ -134,7 +137,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
           indices.data<uint32_t>(),
           out.data<uint32_t>(),
           in.data_size(),
-          nsegments,
+          in.data_size() / nsort,
           offsets,
           offsets + 1,
           stream);
@@ -144,7 +147,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
           in.data<Type>(),
           out.data<Type>(),
           in.data_size(),
-          nsegments,
+          in.data_size() / nsort,
           offsets,
           offsets + 1,
           stream);
@@ -177,4 +180,14 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
   gpu_sort(stream(), inputs[0], out, axis_, false);
 }
 
+void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("ArgPartition::eval_gpu");
+  gpu_sort(stream(), inputs[0], out, axis_, true);
+}
+
+void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Partition::eval_gpu");
+  gpu_sort(stream(), inputs[0], out, axis_, false);
+}
+
 } // namespace mlx::core
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index 23c5fb19c..36388c3c5 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -1,10 +1,8 @@
 cuda_skip = {
     "TestArray.test_api",
-    "TestAutograd.test_update_state",
     "TestBF16.test_arg_reduction_ops",
     "TestBF16.test_reduction_ops",
     "TestBlas.test_complex_gemm",
-    "TestCompile.test_compile_dynamic_dims",
     "TestEinsum.test_ellipses",
     "TestEinsum.test_opt_einsum_test_cases",
     "TestLoad.test_load_f8_e4m3",
@@ -14,24 +12,14 @@ cuda_skip = {
     "TestLayers.test_quantized_embedding",
     "TestLayers.test_sin_pe",
     "TestLayers.test_upsample",
-    "TestOps.test_array_equal",
     "TestOps.test_complex_ops",
     "TestOps.test_dynamic_slicing",
     "TestOps.test_softmax",
-    "TestOps.test_sort",
-    "TestOps.test_tile",
     "TestReduce.test_axis_permutation_sums",
     "TestReduce.test_dtypes",
     "TestReduce.test_expand_sums",
     "TestReduce.test_many_reduction_axes",
"TestUpsample.test_torch_upsample", - # DivMod NYI - "TestOps.test_divmod", - "TestEval.test_multi_output_eval_during_transform", - # Partition NYI - "TestAutograd.test_topk_grad", - "TestOps.test_argpartition", - "TestOps.test_partition", # Block masked matmul NYI "TestBlas.test_block_masked_matmul", # Gather matmul NYI