From bc53f8293f88bd94ca38ef6642cb487e240165db Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 16 Jun 2025 13:14:46 -0700 Subject: [PATCH] Cuda bug fixes 2 (#2298) * more bug fixes * more bug fixes * format --- mlx/backend/cuda/binary.cu | 14 +-- mlx/backend/cuda/compiled.cpp | 2 + mlx/backend/cuda/device/binary_ops.cuh | 22 +++++ mlx/backend/cuda/device/ternary_ops.cuh | 1 + mlx/backend/cuda/device/utils.cuh | 33 ++++++-- mlx/backend/cuda/indexing.cpp | 50 +++++------ mlx/backend/cuda/ternary.cu | 6 +- mlx/backend/cuda/unary.cu | 7 +- mlx/backend/cuda/utils.cpp | 3 + python/tests/cuda_skip.py | 108 +++++++++++------------- python/tests/test_losses.py | 4 +- 11 files changed, 143 insertions(+), 107 deletions(-) diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index d4df06f18..e8e8a8988 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -101,10 +101,12 @@ constexpr bool supports_binary_op() { return std::is_same_v && std::is_same_v; } if (std::is_same_v) { - return std::is_same_v && - (is_floating_v || std::is_same_v); + return std::is_same_v && is_inexact_v; } - if (std::is_same_v || std::is_same_v) { + if (std::is_same_v) { + return std::is_same_v && is_inexact_v; + } + if (std::is_same_v) { return std::is_same_v && is_floating_v; } if (std::is_same_v || std::is_same_v || @@ -150,10 +152,10 @@ void binary_op_gpu_inplace( auto [shape, strides] = collapse_contiguous_dims(a, b, out); auto& a_strides = strides[0]; auto& b_strides = strides[1]; - bool large = a.data_size() > UINT32_MAX || - b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + bool large = a.data_size() > INT32_MAX || + b.data_size() > INT32_MAX || out.data_size() > INT32_MAX; MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; + using IdxT = std::conditional_t; int ndim = shape.size(); if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index a6b8223e0..1aa7ecb92 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -130,11 +130,13 @@ struct FusedKernelBuilder { constexpr const char* g_jit_includes = R"( #include "mlx/backend/cuda/device/binary_ops.cuh" +#include "mlx/backend/cuda/device/ternary_ops.cuh" #include "mlx/backend/cuda/device/unary_ops.cuh" #include "mlx/backend/cuda/device/utils.cuh" #include +#define inf cuda::std::numeric_limits::infinity() )"; void Compiled::eval_gpu( diff --git a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh index b96a7f9cc..ca5ac35e6 100644 --- a/mlx/backend/cuda/device/binary_ops.cuh +++ b/mlx/backend/cuda/device/binary_ops.cuh @@ -1,6 +1,8 @@ // Copyright © 2025 Apple Inc. +#include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/device/fp16_math.cuh" +#include "mlx/backend/cuda/device/utils.cuh" #include #include @@ -122,6 +124,26 @@ struct LogAddExp { ? maxval : T(float(maxval) + log1p(expf(minval - maxval))); }; + + __device__ cuComplex operator()(cuComplex x, cuComplex y) { + if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) || + isnan(cuCimagf(y))) { + return { + cuda::std::numeric_limits::quiet_NaN(), + cuda::std::numeric_limits::quiet_NaN()}; + } + constexpr float inf = cuda::std::numeric_limits::infinity(); + auto maxval = x > y ? x : y; + auto minval = x < y ? x : y; + if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf) + return maxval; + float m = exp(cuCrealf(minval) - cuCrealf(maxval)); + cuComplex dexp{ + m * cos(cuCimagf(minval) - cuCimagf(maxval)), + m * sin(cuCimagf(minval) - cuCimagf(maxval)), + }; + return maxval + log1p(dexp); + } }; struct Maximum { diff --git a/mlx/backend/cuda/device/ternary_ops.cuh b/mlx/backend/cuda/device/ternary_ops.cuh index d1d008ac5..441845471 100644 --- a/mlx/backend/cuda/device/ternary_ops.cuh +++ b/mlx/backend/cuda/device/ternary_ops.cuh @@ -1,4 +1,5 @@ // Copyright © 2025 Apple Inc. +#pragma once namespace mlx::core::cu { diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh index 6f9851c94..54d551992 100644 --- a/mlx/backend/cuda/device/utils.cuh +++ b/mlx/backend/cuda/device/utils.cuh @@ -187,8 +187,8 @@ inline __host__ __device__ cuda::std::tuple elem_to_loc_nd( template inline __host__ __device__ IdxT elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) { - IdxT loc = elem_to_loc_nd<3>(elem, shape, strides); - for (int i = ndim - 1; i >= 3; --i) { + IdxT loc = 0; + for (int i = ndim - 1; i >= 0; --i) { loc += (elem % shape[i]) * IdxT(strides[i]); elem /= shape[i]; } @@ -202,8 +202,9 @@ inline __host__ __device__ cuda::std::tuple elem_to_loc_4d( const int64_t* a_strides, const int64_t* b_strides, int ndim) { - auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides); - for (int i = ndim - 1; i >= 3; --i) { + IdxT a_loc = 0; + IdxT b_loc = 0; + for (int i = ndim - 1; i >= 0; --i) { int dim_idx = elem % shape[i]; a_loc += dim_idx * a_strides[i]; b_loc += dim_idx * b_strides[i]; @@ -220,9 +221,10 @@ inline __host__ __device__ cuda::std::tuple elem_to_loc_4d( const int64_t* b_strides, const int64_t* c_strides, int ndim) { - auto [a_loc, b_loc, c_loc] = - elem_to_loc_nd<3>(elem, shape, a_strides, b_strides, c_strides); - for (int i = ndim - 1; i >= 3; --i) { + IdxT a_loc = 0; + IdxT b_loc = 0; + IdxT c_loc = 0; + for (int i = ndim - 1; i >= 0; --i) { int dim_idx = elem % shape[i]; a_loc += dim_idx * a_strides[i]; b_loc += dim_idx * b_strides[i]; @@ -336,4 +338,21 @@ struct LoopedElemToLoc<1, false, OffsetT> { } }; +inline __device__ cuComplex log1p(cuComplex in) { + float x = cuCrealf(in); + float y = cuCimagf(in); + float zabs = sqrt(x * x + y * y); + float theta = atan2f(y, x + 1); + if (zabs < 0.5f) { + float r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {0.5f * log1pf(r), theta}; + } else { + auto z0 = sqrt((x + 1) * (x + 1) + y * y); + return {log(z0), theta}; + } +} + } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index 3603605c4..65a175fbd 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -65,8 +65,8 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32; int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0; - bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) || - (src.size() > UINT32_MAX) || (out.size() > UINT32_MAX); + bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) || + (src.size() > INT32_MAX) || (out.size() > INT32_MAX); uint32_t slice_size = std::accumulate( slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies()); @@ -88,7 +88,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { dtype_to_cuda_type(idx_dtype), nidx, ndim, - large ? "int64_t" : "uint32_t")); + large ? "int64_t" : "int32_t")); } } return std::make_pair(jit_source_gather, std::move(kernel_names)); @@ -99,7 +99,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { if (large) { mod.append_arg(out.size()); } else { - mod.append_arg(out.size()); + mod.append_arg(out.size()); } mod.append_ndim_arg(src.shape()); mod.append_ndim_arg(src.strides()); @@ -115,7 +115,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { dtype_to_cuda_type(idx_dtype), nidx, idx_ndim, - large ? "int64_t" : "uint32_t"); + large ? "int64_t" : "int32_t"); auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { @@ -152,14 +152,14 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32; int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0; - bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) || - (upd.size() > UINT32_MAX) || (out.size() > UINT32_MAX); + bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) || + (upd.size() > INT32_MAX) || (out.size() > INT32_MAX); - uint32_t upd_post_idx_size = std::accumulate( + int32_t upd_post_idx_size = std::accumulate( upd.shape().begin() + idx_ndim, upd.shape().end(), 1, - std::multiplies()); + std::multiplies()); const char* op = g_scatter_ops[reduce_type_]; std::string module_name = fmt::format( @@ -181,7 +181,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { op, nidx, ndim, - large ? "int64_t" : "uint32_t")); + large ? "int64_t" : "int32_t")); } } return std::make_pair(jit_source_scatter, std::move(kernel_names)); @@ -192,7 +192,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { if (large) { mod.append_arg(upd.size()); } else { - mod.append_arg(upd.size()); + mod.append_arg(upd.size()); } mod.append_ndim_arg(upd.shape()); mod.append_ndim_arg(upd.strides()); @@ -200,7 +200,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { if (large) { mod.append_arg(upd_post_idx_size); } else { - mod.append_arg(upd_post_idx_size); + mod.append_arg(upd_post_idx_size); } mod.append_ndim_arg(out.shape()); mod.append_ndim_arg(out.strides()); @@ -215,7 +215,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { op, nidx, idx_ndim, - large ? "int64_t" : "uint32_t"); + large ? "int64_t" : "int32_t"); auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { @@ -238,7 +238,7 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { return; } - bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX; + bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; std::string module_name = fmt::format( "gather_axis_{}_{}", @@ -258,7 +258,7 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { ndim, contiguous & 1 ? true : false, contiguous & 2 ? true : false, - large ? "int64_t" : "uint32_t")); + large ? "int64_t" : "int32_t")); } } } @@ -283,9 +283,9 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { mod.append_arg(idx_size_axis); mod.append_arg(idx_size_post); } else { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); } mod.append_arg(remove_index(idx.shape(), axis_)); mod.append_arg(remove_index(src.strides(), axis_)); @@ -302,7 +302,7 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { src.ndim() - 1, src.flags().row_contiguous, idx.flags().row_contiguous, - large ? "int64_t" : "uint32_t"); + large ? "int64_t" : "int32_t"); auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { @@ -337,7 +337,7 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { return; } - bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX; + bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign"; std::string module_name = fmt::format( @@ -360,7 +360,7 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { ndim, contiguous & 1 ? true : false, contiguous & 2 ? true : false, - large ? "int64_t" : "uint32_t")); + large ? "int64_t" : "int32_t")); } } } @@ -385,9 +385,9 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { mod.append_arg(idx_size_axis); mod.append_arg(idx_size_post); } else { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); } mod.append_arg(remove_index(idx.shape(), axis_)); mod.append_arg(remove_index(upd.strides(), axis_)); @@ -405,7 +405,7 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { idx.ndim() - 1, upd.flags().row_contiguous, idx.flags().row_contiguous, - large ? "int64_t" : "uint32_t"); + large ? "int64_t" : "int32_t"); auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index 02e46afc1..e33af3c80 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -101,10 +101,10 @@ void ternary_op_gpu_inplace( auto& a_strides = strides[0]; auto& b_strides = strides[1]; auto& c_strides = strides[2]; - bool large = a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX || - c.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX; + bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + c.data_size() > INT32_MAX || out.data_size() > INT32_MAX; MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; + using IdxT = std::conditional_t; int ndim = shape.size(); if (ndim <= 3) { MLX_SWITCH_1_2_3(ndim, NDIM, { diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index d2fa96381..e45144eda 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -27,13 +27,12 @@ constexpr bool supports_unary_op() { std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { return std::is_same_v && is_floating_v; } if (std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || std::is_same_v) { return std::is_same_v && is_inexact_v; } if (std::is_same_v) { diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 2f5e2a4c8..4a3d8be30 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -31,6 +31,9 @@ const char* dtype_to_cuda_type(const Dtype& dtype) { if (dtype == bfloat16) { return "__nv_bfloat16"; } + if (dtype == complex64) { + return "cuComplex"; + } #define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ if (dtype == DTYPE) { \ return #CPP_TYPE; \ diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 0072db192..23c5fb19c 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -1,24 +1,50 @@ cuda_skip = { "TestArray.test_api", - "TestAutograd.test_cumprod_grad", - "TestAutograd.test_slice_grads", - "TestAutograd.test_split_against_slice", - "TestAutograd.test_stop_gradient", - "TestAutograd.test_topk_grad", "TestAutograd.test_update_state", - "TestAutograd.test_vjp", "TestBF16.test_arg_reduction_ops", - "TestBF16.test_binary_ops", "TestBF16.test_reduction_ops", - "TestBlas.test_block_masked_matmul", "TestBlas.test_complex_gemm", + "TestCompile.test_compile_dynamic_dims", + "TestEinsum.test_ellipses", + "TestEinsum.test_opt_einsum_test_cases", + "TestLoad.test_load_f8_e4m3", + "TestMemory.test_memory_info", + "TestLayers.test_group_norm", + "TestLayers.test_pooling", + "TestLayers.test_quantized_embedding", + "TestLayers.test_sin_pe", + "TestLayers.test_upsample", + "TestOps.test_array_equal", + "TestOps.test_complex_ops", + "TestOps.test_dynamic_slicing", + "TestOps.test_softmax", + "TestOps.test_sort", + "TestOps.test_tile", + "TestReduce.test_axis_permutation_sums", + "TestReduce.test_dtypes", + "TestReduce.test_expand_sums", + "TestReduce.test_many_reduction_axes", + "TestUpsample.test_torch_upsample", + # DivMod NYI + "TestOps.test_divmod", + "TestEval.test_multi_output_eval_during_transform", + # Partition NYI + "TestAutograd.test_topk_grad", + "TestOps.test_argpartition", + "TestOps.test_partition", + # Block masked matmul NYI + "TestBlas.test_block_masked_matmul", + # Gather matmul NYI "TestBlas.test_gather_matmul", "TestBlas.test_gather_matmul_grad", - "TestBlas.test_matmul_batched", - "TestBlas.test_matrix_vector_attn", - "TestCompile.test_compile_dynamic_dims", - "TestCompile.test_compile_inf", - "TestCompile.test_inf_constant", + # Scan NYI + "TestAutograd.test_cumprod_grad", + "TestOps.test_scans", + "TestOps.test_logcumsumexp", + # Hadamard NYI + "TestOps.test_hadamard", + "TestOps.test_hadamard_grad_vmap", + # Convolutions NYI "TestConv.test_1d_conv_with_2d", "TestConv.test_asymmetric_padding", "TestConv.test_basic_grad_shapes", @@ -45,11 +71,11 @@ cuda_skip = { "TestConvTranspose.test_torch_conv_transpose_3D", "TestConvTranspose.test_torch_conv_transpose_3D_grad", "TestConvTranspose.test_torch_conv_transpose_3d_output_padding", - "TestEinsum.test_attention", - "TestEinsum.test_ellipses", - "TestEinsum.test_opt_einsum_test_cases", - "TestEval.test_multi_output_eval_during_transform", "TestExportImport.test_export_conv", + "TestLayers.test_conv1d", + "TestLayers.test_conv2d", + "TestVmap.test_vmap_conv", + # FFTs NYI "TestFFT.test_fft", "TestFFT.test_fft_big_powers_of_two", "TestFFT.test_fft_contiguity", @@ -59,52 +85,22 @@ cuda_skip = { "TestFFT.test_fft_large_numbers", "TestFFT.test_fft_shared_mem", "TestFFT.test_fftn", - "TestInit.test_orthogonal", + # Lapack ops NYI "TestLinalg.test_cholesky", "TestLinalg.test_cholesky_inv", "TestLinalg.test_eig", "TestLinalg.test_eigh", "TestLinalg.test_inverse", + "TestVmap.test_vmap_inverse", "TestLinalg.test_lu", "TestLinalg.test_lu_factor", "TestLinalg.test_pseudo_inverse", "TestLinalg.test_qr_factorization", + "TestInit.test_orthogonal", "TestLinalg.test_svd_decomposition", + "TestVmap.test_vmap_svd", "TestLinalg.test_tri_inverse", - "TestLoad.test_load_f8_e4m3", - "TestLosses.test_binary_cross_entropy", - "TestMemory.test_memory_info", - "TestLayers.test_conv1d", - "TestLayers.test_conv2d", - "TestLayers.test_elu", - "TestLayers.test_group_norm", - "TestLayers.test_hard_shrink", - "TestLayers.test_pooling", - "TestLayers.test_quantized_embedding", - "TestLayers.test_sin_pe", - "TestLayers.test_softshrink", - "TestLayers.test_upsample", - "TestOps.test_argpartition", - "TestOps.test_array_equal", - "TestOps.test_as_strided", - "TestOps.test_binary_ops", - "TestOps.test_bitwise_grad", - "TestOps.test_complex_ops", - "TestOps.test_divmod", - "TestOps.test_dynamic_slicing", - "TestOps.test_hadamard", - "TestOps.test_hadamard_grad_vmap", - "TestOps.test_irregular_binary_ops", - "TestOps.test_kron", - "TestOps.test_log1p", - "TestOps.test_logaddexp", - "TestOps.test_logcumsumexp", - "TestOps.test_partition", - "TestOps.test_scans", - "TestOps.test_softmax", - "TestOps.test_sort", - "TestOps.test_tensordot", - "TestOps.test_tile", + # Quantization NYI "TestQuantized.test_gather_matmul_grad", "TestQuantized.test_gather_qmm", "TestQuantized.test_gather_qmm_sorted", @@ -120,12 +116,4 @@ cuda_skip = { "TestQuantized.test_small_matrix", "TestQuantized.test_throw", "TestQuantized.test_vjp_scales_biases", - "TestReduce.test_axis_permutation_sums", - "TestReduce.test_dtypes", - "TestReduce.test_expand_sums", - "TestReduce.test_many_reduction_axes", - "TestUpsample.test_torch_upsample", - "TestVmap.test_vmap_conv", - "TestVmap.test_vmap_inverse", - "TestVmap.test_vmap_svd", } diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py index cbc657655..2ef1fa36c 100644 --- a/python/tests/test_losses.py +++ b/python/tests/test_losses.py @@ -83,14 +83,14 @@ class TestLosses(mlx_tests.MLXTestCase): logits, targets, reduction="mean" ) expected_mean = mx.mean(expected_none) - self.assertEqual(losses_mean, expected_mean) + self.assertTrue(mx.allclose(losses_mean, expected_mean)) # Test with reduction 'sum' losses_sum = nn.losses.binary_cross_entropy( logits, targets, reduction="sum" ) expected_sum = mx.sum(expected_none) - self.assertEqual(losses_sum, expected_sum) + self.assertTrue(mx.allclose(losses_sum, expected_sum)) # With weights, no label smoothing weights = mx.array([1.0, 2.0, 1.0, 2.0])