diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index be8fca8d4..e8e8a8988 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -101,10 +101,12 @@ constexpr bool supports_binary_op() { return std::is_same_v && std::is_same_v; } if (std::is_same_v) { - return std::is_same_v && - (is_floating_v || std::is_same_v); + return std::is_same_v && is_inexact_v; } - if (std::is_same_v || std::is_same_v) { + if (std::is_same_v) { + return std::is_same_v && is_inexact_v; + } + if (std::is_same_v) { return std::is_same_v && is_floating_v; } if (std::is_same_v || std::is_same_v || diff --git a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh index b96a7f9cc..d847200d4 100644 --- a/mlx/backend/cuda/device/binary_ops.cuh +++ b/mlx/backend/cuda/device/binary_ops.cuh @@ -1,6 +1,8 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/device/fp16_math.cuh" +#include "mlx/backend/cuda/device/cucomplex_math.cuh" +#include "mlx/backend/cuda/device/utils.cuh" #include #include @@ -122,6 +124,24 @@ struct LogAddExp { ? maxval : T(float(maxval) + log1p(expf(minval - maxval))); }; + + __device__ cuComplex operator()(cuComplex x, cuComplex y) { + if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) || + isnan(cuCimagf(y))) { + return {cuda::std::numeric_limits::quiet_NaN(), cuda::std::numeric_limits::quiet_NaN()}; + } + constexpr float inf = cuda::std::numeric_limits::infinity(); + auto maxval = x > y ? x : y; + auto minval = x < y ? x : y; + if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf) + return maxval; + float m = exp(cuCrealf(minval) - cuCrealf(maxval)); + cuComplex dexp{ + m * cos(cuCimagf(minval) - cuCimagf(maxval)), + m * sin(cuCimagf(minval) - cuCimagf(maxval)), + }; + return maxval + log1p(dexp); + } }; struct Maximum { diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh index d2897203f..54d551992 100644 --- a/mlx/backend/cuda/device/utils.cuh +++ b/mlx/backend/cuda/device/utils.cuh @@ -187,8 +187,8 @@ inline __host__ __device__ cuda::std::tuple elem_to_loc_nd( template inline __host__ __device__ IdxT elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) { - IdxT loc = elem_to_loc_nd<3>(elem, shape, strides); - for (int i = ndim - 1; i >= 3; --i) { + IdxT loc = 0; + for (int i = ndim - 1; i >= 0; --i) { loc += (elem % shape[i]) * IdxT(strides[i]); elem /= shape[i]; } @@ -202,8 +202,9 @@ inline __host__ __device__ cuda::std::tuple elem_to_loc_4d( const int64_t* a_strides, const int64_t* b_strides, int ndim) { - auto [a_loc, b_loc] = elem_to_loc_nd<3>(elem, shape, a_strides, b_strides); - for (int i = ndim - 1; i >= 3; --i) { + IdxT a_loc = 0; + IdxT b_loc = 0; + for (int i = ndim - 1; i >= 0; --i) { int dim_idx = elem % shape[i]; a_loc += dim_idx * a_strides[i]; b_loc += dim_idx * b_strides[i]; @@ -220,9 +221,10 @@ inline __host__ __device__ cuda::std::tuple elem_to_loc_4d( const int64_t* b_strides, const int64_t* c_strides, int ndim) { - auto [a_loc, b_loc, c_loc] = - elem_to_loc_nd<3>(elem, shape, a_strides, b_strides, c_strides); - for (int i = ndim - 1; i >= 3; --i) { + IdxT a_loc = 0; + IdxT b_loc = 0; + IdxT c_loc = 0; + for (int i = ndim - 1; i >= 0; --i) { int dim_idx = elem % shape[i]; a_loc += dim_idx * a_strides[i]; b_loc += dim_idx * b_strides[i]; diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index 3603605c4..65a175fbd 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -65,8 +65,8 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32; int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0; - bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) || - (src.size() > UINT32_MAX) || (out.size() > UINT32_MAX); + bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) || + (src.size() > INT32_MAX) || (out.size() > INT32_MAX); uint32_t slice_size = std::accumulate( slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies()); @@ -88,7 +88,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { dtype_to_cuda_type(idx_dtype), nidx, ndim, - large ? "int64_t" : "uint32_t")); + large ? "int64_t" : "int32_t")); } } return std::make_pair(jit_source_gather, std::move(kernel_names)); @@ -99,7 +99,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { if (large) { mod.append_arg(out.size()); } else { - mod.append_arg(out.size()); + mod.append_arg(out.size()); } mod.append_ndim_arg(src.shape()); mod.append_ndim_arg(src.strides()); @@ -115,7 +115,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { dtype_to_cuda_type(idx_dtype), nidx, idx_ndim, - large ? "int64_t" : "uint32_t"); + large ? "int64_t" : "int32_t"); auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { @@ -152,14 +152,14 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32; int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0; - bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) || - (upd.size() > UINT32_MAX) || (out.size() > UINT32_MAX); + bool large = (nidx > 0 && inputs[1].size() > INT32_MAX) || + (upd.size() > INT32_MAX) || (out.size() > INT32_MAX); - uint32_t upd_post_idx_size = std::accumulate( + int32_t upd_post_idx_size = std::accumulate( upd.shape().begin() + idx_ndim, upd.shape().end(), 1, - std::multiplies()); + std::multiplies()); const char* op = g_scatter_ops[reduce_type_]; std::string module_name = fmt::format( @@ -181,7 +181,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { op, nidx, ndim, - large ? "int64_t" : "uint32_t")); + large ? "int64_t" : "int32_t")); } } return std::make_pair(jit_source_scatter, std::move(kernel_names)); @@ -192,7 +192,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { if (large) { mod.append_arg(upd.size()); } else { - mod.append_arg(upd.size()); + mod.append_arg(upd.size()); } mod.append_ndim_arg(upd.shape()); mod.append_ndim_arg(upd.strides()); @@ -200,7 +200,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { if (large) { mod.append_arg(upd_post_idx_size); } else { - mod.append_arg(upd_post_idx_size); + mod.append_arg(upd_post_idx_size); } mod.append_ndim_arg(out.shape()); mod.append_ndim_arg(out.strides()); @@ -215,7 +215,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { op, nidx, idx_ndim, - large ? "int64_t" : "uint32_t"); + large ? "int64_t" : "int32_t"); auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { @@ -238,7 +238,7 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { return; } - bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX; + bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; std::string module_name = fmt::format( "gather_axis_{}_{}", @@ -258,7 +258,7 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { ndim, contiguous & 1 ? true : false, contiguous & 2 ? true : false, - large ? "int64_t" : "uint32_t")); + large ? "int64_t" : "int32_t")); } } } @@ -283,9 +283,9 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { mod.append_arg(idx_size_axis); mod.append_arg(idx_size_post); } else { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); } mod.append_arg(remove_index(idx.shape(), axis_)); mod.append_arg(remove_index(src.strides(), axis_)); @@ -302,7 +302,7 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { src.ndim() - 1, src.flags().row_contiguous, idx.flags().row_contiguous, - large ? "int64_t" : "uint32_t"); + large ? "int64_t" : "int32_t"); auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { @@ -337,7 +337,7 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { return; } - bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX; + bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX; const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign"; std::string module_name = fmt::format( @@ -360,7 +360,7 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { ndim, contiguous & 1 ? true : false, contiguous & 2 ? true : false, - large ? "int64_t" : "uint32_t")); + large ? "int64_t" : "int32_t")); } } } @@ -385,9 +385,9 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { mod.append_arg(idx_size_axis); mod.append_arg(idx_size_post); } else { - mod.append_arg(idx_size_pre); - mod.append_arg(idx_size_axis); - mod.append_arg(idx_size_post); + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); } mod.append_arg(remove_index(idx.shape(), axis_)); mod.append_arg(remove_index(upd.strides(), axis_)); @@ -405,7 +405,7 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { idx.ndim() - 1, upd.flags().row_contiguous, idx.flags().row_contiguous, - large ? "int64_t" : "uint32_t"); + large ? "int64_t" : "int32_t"); auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 2f5e2a4c8..4a3d8be30 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -31,6 +31,9 @@ const char* dtype_to_cuda_type(const Dtype& dtype) { if (dtype == bfloat16) { return "__nv_bfloat16"; } + if (dtype == complex64) { + return "cuComplex"; + } #define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ if (dtype == DTYPE) { \ return #CPP_TYPE; \ diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index d3f3e4bda..cd09de0c4 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -1,23 +1,13 @@ cuda_skip = { "TestArray.test_api", - "TestAutograd.test_slice_grads", - "TestAutograd.test_split_against_slice", - "TestAutograd.test_stop_gradient", "TestAutograd.test_update_state", - "TestAutograd.test_vjp", "TestBF16.test_arg_reduction_ops", - "TestBF16.test_binary_ops", "TestBF16.test_reduction_ops", "TestBlas.test_complex_gemm", - "TestBlas.test_matmul_batched", - "TestBlas.test_matrix_vector_attn", "TestCompile.test_compile_dynamic_dims", - "TestEinsum.test_attention", "TestEinsum.test_ellipses", "TestEinsum.test_opt_einsum_test_cases", - "TestEval.test_multi_output_eval_during_transform", "TestLoad.test_load_f8_e4m3", - "TestLosses.test_binary_cross_entropy", "TestMemory.test_memory_info", "TestLayers.test_group_norm", "TestLayers.test_pooling", @@ -26,14 +16,9 @@ cuda_skip = { "TestLayers.test_upsample", "TestOps.test_array_equal", "TestOps.test_complex_ops", - "TestOps.test_divmod", "TestOps.test_dynamic_slicing", - "TestOps.test_irregular_binary_ops", - "TestOps.test_kron", - "TestOps.test_logaddexp", "TestOps.test_softmax", "TestOps.test_sort", - "TestOps.test_tensordot", "TestOps.test_tile", "TestReduce.test_axis_permutation_sums", "TestReduce.test_dtypes", @@ -41,6 +26,10 @@ cuda_skip = { "TestReduce.test_many_reduction_axes", "TestUpsample.test_torch_upsample", + # DivMod NYI + "TestOps.test_divmod", + "TestEval.test_multi_output_eval_during_transform", + # Partition NYI "TestAutograd.test_topk_grad", "TestOps.test_argpartition", diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py index cbc657655..2ef1fa36c 100644 --- a/python/tests/test_losses.py +++ b/python/tests/test_losses.py @@ -83,14 +83,14 @@ class TestLosses(mlx_tests.MLXTestCase): logits, targets, reduction="mean" ) expected_mean = mx.mean(expected_none) - self.assertEqual(losses_mean, expected_mean) + self.assertTrue(mx.allclose(losses_mean, expected_mean)) # Test with reduction 'sum' losses_sum = nn.losses.binary_cross_entropy( logits, targets, reduction="sum" ) expected_sum = mx.sum(expected_none) - self.assertEqual(losses_sum, expected_sum) + self.assertTrue(mx.allclose(losses_sum, expected_sum)) # With weights, no label smoothing weights = mx.array([1.0, 2.0, 1.0, 2.0])