diff --git a/mlx/backend/cuda/device/complex.cuh b/mlx/backend/cuda/device/complex.cuh
index 8dfd23b46..03a7bff83 100644
--- a/mlx/backend/cuda/device/complex.cuh
+++ b/mlx/backend/cuda/device/complex.cuh
@@ -38,14 +38,13 @@ inline __host__ __device__ complex_t<T> operator%(
 }
 
 template <typename T>
-inline __host__ __device__ bool operator<(complex_t<T> a, complex_t<T> b) {
-  return (a.real() * a.real() + a.imag() * a.imag()) <
-      (b.real() * b.real() + b.imag() * b.imag());
+inline __host__ __device__ bool operator>(complex_t<T> a, complex_t<T> b) {
+  return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
 }
 
 template <typename T>
-inline __host__ __device__ bool operator>(complex_t<T> a, complex_t<T> b) {
-  return b < a;
+inline __host__ __device__ bool operator<(complex_t<T> a, complex_t<T> b) {
+  return operator>(b, a);
 }
 
 template <typename T>
diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh
index 31ba90433..7f8cad0c4 100644
--- a/mlx/backend/cuda/reduce/reduce_ops.cuh
+++ b/mlx/backend/cuda/reduce/reduce_ops.cuh
@@ -69,6 +69,18 @@ struct Prod {
 struct Min {
   template <typename T>
   __device__ __forceinline__ T operator()(T a, T b) {
+    if constexpr (is_complex_v<T>) {
+      if (isnan(a.real()) || isnan(a.imag())) {
+        return a;
+      }
+      if (isnan(b.real()) || isnan(b.imag())) {
+        return b;
+      }
+    } else if constexpr (!cuda::std::is_integral_v<T>) {
+      if (isnan(a) || isnan(b)) {
+        return cuda::std::numeric_limits<T>::quiet_NaN();
+      }
+    }
     return a < b ? a : b;
   }
 
@@ -81,6 +93,18 @@ struct Min {
 struct Max {
   template <typename T>
   __device__ __forceinline__ T operator()(T a, T b) {
+    if constexpr (is_complex_v<T>) {
+      if (isnan(a.real()) || isnan(a.imag())) {
+        return a;
+      }
+      if (isnan(b.real()) || isnan(b.imag())) {
+        return b;
+      }
+    } else if constexpr (!cuda::std::is_integral_v<T>) {
+      if (isnan(a) || isnan(b)) {
+        return cuda::std::numeric_limits<T>::quiet_NaN();
+      }
+    }
     return a > b ? a : b;
   }
 
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index 7c9ff84ce..50cb8dcbe 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -2,9 +2,6 @@ cuda_skip = {
     "TestLoad.test_load_f8_e4m3",
     "TestLayers.test_quantized_embedding",
     "TestOps.test_dynamic_slicing",
-    "TestReduce.test_dtypes",
-    "TestReduce.test_nanpropagation",
-    "TestReduce.test_nanpropagation_complex64",
     # Block masked matmul NYI
     "TestBlas.test_block_masked_matmul",
     # Gather matmul NYI
diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py
index 9efd6c5c7..d6ddf353b 100644
--- a/python/tests/test_reduce.py
+++ b/python/tests/test_reduce.py
@@ -153,7 +153,7 @@ class TestReduce(mlx_tests.MLXTestCase):
         x = x.transpose(1, 0, 2, 3, 4, 5, 6, 7, 8, 9)
         check(x, (1, 3, 5, 7, 9))
 
-    def test_nanpropagation(self):
+    def test_nan_propagation(self):
         dtypes = [
             "uint8",
             "uint16",
@@ -179,7 +179,7 @@ class TestReduce(mlx_tests.MLXTestCase):
             ref = getattr(np, op)(x_np, axis=axis)
             self.assertTrue(np.array_equal(out, ref, equal_nan=True))
 
-    def test_nanpropagation_complex64(self):
+    def test_nan_propagation_complex64(self):
         complex_array_1 = mx.array(
             [1 + 1j, 2 + 2j, 3 + 3j, mx.nan + 4j], dtype=mx.complex64
         ).reshape(2, 2)
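
Not part of the patch: a minimal sketch of the behavior these changes are expected to produce, runnable against an MLX build that includes them (assumes the CUDA backend is active; uses only the public mx.array/mx.min/mx.max API).

import numpy as np
import mlx.core as mx

# NaN inputs now propagate through min/max reductions instead of being
# silently dropped by the a < b / a > b comparisons.
x = mx.array([1.0, float("nan"), 3.0])
print(mx.max(x))  # nan
print(mx.min(x))  # nan

# Complex min/max now compare lexicographically (real part, then imaginary
# part) rather than by magnitude, which agrees with NumPy's ordering for
# complex values.
z = mx.array([1 + 4j, 2 + 2j, 2 + 3j], dtype=mx.complex64)
print(mx.max(z))            # (2+3j)
print(np.max(np.array(z)))  # (2+3j), same ordering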