From abcf62ee55ab959b6613437d4ca8b8b9ee50595a Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 16 Jun 2025 12:35:26 -0700
Subject: [PATCH] format

---
 mlx/backend/cuda/device/binary_ops.cuh |  8 +++++---
 mlx/backend/cuda/unary.cu              |  5 ++---
 python/tests/cuda_skip.py              | 10 ----------
 3 files changed, 7 insertions(+), 16 deletions(-)

diff --git a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh
index d847200d4..ca5ac35e6 100644
--- a/mlx/backend/cuda/device/binary_ops.cuh
+++ b/mlx/backend/cuda/device/binary_ops.cuh
@@ -1,7 +1,7 @@
 // Copyright © 2025 Apple Inc.
 
-#include "mlx/backend/cuda/device/fp16_math.cuh"
 #include "mlx/backend/cuda/device/cucomplex_math.cuh"
+#include "mlx/backend/cuda/device/fp16_math.cuh"
 #include "mlx/backend/cuda/device/utils.cuh"
 
 #include 
@@ -124,11 +124,13 @@ struct LogAddExp {
         ? maxval
         : T(float(maxval) + log1p(expf(minval - maxval)));
   };
- 
+
   __device__ cuComplex operator()(cuComplex x, cuComplex y) {
     if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
         isnan(cuCimagf(y))) {
-      return {cuda::std::numeric_limits<float>::quiet_NaN(), cuda::std::numeric_limits<float>::quiet_NaN()};
+      return {
+          cuda::std::numeric_limits<float>::quiet_NaN(),
+          cuda::std::numeric_limits<float>::quiet_NaN()};
     }
     constexpr float inf = cuda::std::numeric_limits<float>::infinity();
     auto maxval = x > y ? x : y;
diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu
index 83e206ff9..e45144eda 100644
--- a/mlx/backend/cuda/unary.cu
+++ b/mlx/backend/cuda/unary.cu
@@ -27,9 +27,8 @@ constexpr bool supports_unary_op() {
       std::is_same_v || std::is_same_v ||
       std::is_same_v || std::is_same_v ||
       std::is_same_v || std::is_same_v ||
-      std::is_same_v ||
-      std::is_same_v || std::is_same_v ||
-      std::is_same_v) {
+      std::is_same_v || std::is_same_v ||
+      std::is_same_v || std::is_same_v) {
     return std::is_same_v && is_floating_v;
   }
   if (std::is_same_v || std::is_same_v ||
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index cd09de0c4..23c5fb19c 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -25,32 +25,25 @@ cuda_skip = {
     "TestReduce.test_expand_sums",
     "TestReduce.test_many_reduction_axes",
     "TestUpsample.test_torch_upsample",
-
     # DivMod NYI
     "TestOps.test_divmod",
     "TestEval.test_multi_output_eval_during_transform",
-
     # Partition NYI
     "TestAutograd.test_topk_grad",
     "TestOps.test_argpartition",
     "TestOps.test_partition",
-
     # Block masked matmul NYI
     "TestBlas.test_block_masked_matmul",
-
     # Gather matmul NYI
     "TestBlas.test_gather_matmul",
     "TestBlas.test_gather_matmul_grad",
-
     # Scan NYI
     "TestAutograd.test_cumprod_grad",
     "TestOps.test_scans",
     "TestOps.test_logcumsumexp",
-
     # Hadamard NYI
     "TestOps.test_hadamard",
     "TestOps.test_hadamard_grad_vmap",
-
     # Convolutions NYI
     "TestConv.test_1d_conv_with_2d",
     "TestConv.test_asymmetric_padding",
@@ -82,7 +75,6 @@
     "TestLayers.test_conv1d",
     "TestLayers.test_conv2d",
     "TestVmap.test_vmap_conv",
-
     # FFTs NYI
     "TestFFT.test_fft",
     "TestFFT.test_fft_big_powers_of_two",
     "TestFFT.test_fft_contiguity",
     "TestFFT.test_fft_exhaustive",
     "TestFFT.test_fft_grads",
     "TestFFT.test_fft_into_ifft",
@@ -93,7 +85,6 @@
     "TestFFT.test_fft_large_numbers",
     "TestFFT.test_fft_shared_mem",
     "TestFFT.test_fftn",
-
     # Lapack ops NYI
     "TestLinalg.test_cholesky",
     "TestLinalg.test_cholesky_inv",
     "TestLinalg.test_eig",
     "TestLinalg.test_eigh",
     "TestLinalg.test_inverse",
     "TestVmap.test_vmap_inverse",
     "TestLinalg.test_lu",
     "TestLinalg.test_lu_factor",
     "TestLinalg.test_pseudo_inverse",
     "TestLinalg.test_qr_factorization",
     "TestInit.test_orthogonal",
@@ -109,7 +100,6 @@
     "TestLinalg.test_svd_decomposition",
     "TestVmap.test_vmap_svd",
     "TestLinalg.test_tri_inverse",
-
     # Quantization NYI
     "TestQuantized.test_gather_matmul_grad",
     "TestQuantized.test_gather_qmm",