From dd4f53db63020ede8b8abf6eec91a35e92dc73c1 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 1 Jul 2025 07:30:00 -0700
Subject: [PATCH] use fp32 for testing, add more complex ops (#2322)

---
 mlx/backend/cuda/device/unary_ops.cuh | 54 +++++++++++++++++++++++----
 mlx/backend/cuda/layer_norm.cu        |  2 -
 mlx/backend/cuda/rms_norm.cu          |  1 -
 mlx/backend/cuda/unary.cu             | 35 ++++++++---------
 python/tests/cuda_skip.py             | 12 +-----
 python/tests/mlx_tests.py             |  4 ++
 6 files changed, 68 insertions(+), 40 deletions(-)

diff --git a/mlx/backend/cuda/device/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh
index efa9133b1..18d769c2a 100644
--- a/mlx/backend/cuda/device/unary_ops.cuh
+++ b/mlx/backend/cuda/device/unary_ops.cuh
@@ -27,6 +27,8 @@ struct ArcCos {
   __device__ T operator()(T x) {
     return acos(x);
   }
+
+  __device__ cuComplex operator()(cuComplex x);
 };
 
 struct ArcCosh {
@@ -41,6 +43,8 @@ struct ArcSin {
   __device__ T operator()(T x) {
     return asin(x);
   }
+
+  __device__ cuComplex operator()(cuComplex x);
 };
 
 struct ArcSinh {
@@ -55,6 +59,8 @@ struct ArcTan {
   __device__ T operator()(T x) {
     return atan(x);
   }
+
+  __device__ cuComplex operator()(cuComplex x);
 };
 
 struct ArcTanh {
@@ -261,13 +267,6 @@ struct Round {
   }
 };
 
-struct Rsqrt {
-  template <typename T>
-  __device__ T operator()(T x) {
-    return rsqrt(x);
-  }
-};
-
 struct Sigmoid {
   template <typename T>
   __device__ T operator()(T x) {
@@ -333,6 +332,29 @@ struct Sqrt {
   __device__ T operator()(T x) {
     return sqrt(x);
   }
+
+  __device__ cuComplex operator()(cuComplex x) {
+    auto xr = cuCrealf(x);
+    auto xi = cuCimagf(x);
+    if (xr == 0.0f && xi == 0.0f) {
+      return {0.0f, 0.0f};
+    }
+    auto r = cuCrealf(Abs{}(x));
+    auto a = sqrt((r + xr) / 2.0f);
+    auto b_abs = sqrt((r - xr) / 2.0f);
+    auto b = copysign(b_abs, xi);
+    return {a, b};
+  }
+};
+
+struct Rsqrt {
+  template <typename T>
+  __device__ T operator()(T x) {
+    return rsqrt(x);
+  }
+  __device__ cuComplex operator()(cuComplex x) {
+    return 1.0f / Sqrt{}(x);
+  }
 };
 
 struct Tan {
@@ -365,4 +387,22 @@ struct Tanh {
   }
 };
 
+__device__ cuComplex ArcCos::operator()(cuComplex x) {
+  auto i = cuComplex{0.0, 1.0};
+  auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
+  return {cuCimagf(y), -cuCrealf(y)};
+};
+
+__device__ cuComplex ArcSin::operator()(cuComplex x) {
+  auto i = cuComplex{0.0f, 1.0f};
+  auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
+  return {cuCimagf(y), -cuCrealf(y)};
+};
+
+__device__ cuComplex ArcTan::operator()(cuComplex x) {
+  auto i = cuComplex{0.0f, 1.0f};
+  auto ix = i * x;
+  return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
+};
+
 } // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu
index 23f0b168f..852cf43af 100644
--- a/mlx/backend/cuda/layer_norm.cu
+++ b/mlx/backend/cuda/layer_norm.cu
@@ -342,8 +342,6 @@ void LayerNormVJP::eval_gpu(
       encoder.add_temporary(gw_temp);
     }
   }
-  gw.set_data(allocator::malloc(gw.nbytes()));
-  gb.set_data(allocator::malloc(gb.nbytes()));
 
   // Finish with the gradient for b in case we had a b.
   if (gb.ndim() == 1 && gb.size() == axis_size) {
diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu
index 7b87f2947..7f5f9630d 100644
--- a/mlx/backend/cuda/rms_norm.cu
+++ b/mlx/backend/cuda/rms_norm.cu
@@ -304,7 +304,6 @@ void RMSNormVJP::eval_gpu(
       encoder.add_temporary(gw_temp);
     }
   }
-  gw.set_data(allocator::malloc(gw.nbytes()));
 
   encoder.set_input_array(x);
   encoder.set_input_array(w);
diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu
index 4f9bac29f..74251d1f6 100644
--- a/mlx/backend/cuda/unary.cu
+++ b/mlx/backend/cuda/unary.cu
@@ -20,38 +20,35 @@ namespace cu {
 template
 constexpr bool supports_unary_op() {
   if (std::is_same_v || std::is_same_v ||
-      std::is_same_v) {
+      std::is_same_v || std::is_same_v) {
     return std::is_same_v;
   }
-  if (std::is_same_v || std::is_same_v ||
-      std::is_same_v || std::is_same_v ||
-      std::is_same_v || std::is_same_v ||
-      std::is_same_v || std::is_same_v ||
-      std::is_same_v || std::is_same_v ||
-      std::is_same_v || std::is_same_v) {
+  if (std::is_same_v || std::is_same_v ||
+      std::is_same_v || std::is_same_v ||
+      std::is_same_v || std::is_same_v ||
+      std::is_same_v) {
     return std::is_same_v && is_floating_v;
   }
-  if (std::is_same_v || std::is_same_v ||
-      std::is_same_v || std::is_same_v) {
-    return std::is_same_v && is_inexact_v;
-  }
   if (std::is_same_v) {
     return std::is_same_v && std::is_integral_v &&
         !std::is_same_v;
   }
-  if (std::is_same_v || std::is_same_v ||
-      std::is_same_v) {
+  if (std::is_same_v || std::is_same_v) {
     return std::is_same_v && !std::is_same_v;
   }
   if (std::is_same_v) {
     return std::is_same_v && std::is_same_v;
   }
-  if (std::is_same_v || std::is_same_v ||
-      std::is_same_v || std::is_same_v ||
-      std::is_same_v || std::is_same_v ||
-      std::is_same_v || std::is_same_v) {
-    return std::is_same_v &&
-        (is_floating_v || std::is_same_v);
+  if (std::is_same_v || std::is_same_v ||
+      std::is_same_v || std::is_same_v ||
+      std::is_same_v || std::is_same_v ||
+      std::is_same_v || std::is_same_v ||
+      std::is_same_v || std::is_same_v ||
+      std::is_same_v || std::is_same_v ||
+      std::is_same_v || std::is_same_v ||
+      std::is_same_v || std::is_same_v ||
+      std::is_same_v) {
+    return std::is_same_v && is_inexact_v;
   }
   if (std::is_same_v || std::is_same_v) {
     return std::is_same_v && std::is_same_v;
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index cba642ca1..fce92bacb 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -1,25 +1,15 @@
 cuda_skip = {
-    "TestArray.test_api",
-    "TestBF16.test_arg_reduction_ops",
-    "TestBlas.test_complex_gemm",
-    "TestEinsum.test_ellipses",
-    "TestEinsum.test_opt_einsum_test_cases",
     "TestLoad.test_load_f8_e4m3",
-    "TestLayers.test_group_norm",
-    "TestLayers.test_pooling",
     "TestLayers.test_quantized_embedding",
-    "TestLayers.test_sin_pe",
-    "TestLayers.test_upsample",
-    "TestOps.test_complex_ops",
     "TestOps.test_dynamic_slicing",
     "TestReduce.test_dtypes",
-    "TestUpsample.test_torch_upsample",
     # Block masked matmul NYI
     "TestBlas.test_block_masked_matmul",
     # Gather matmul NYI
     "TestBlas.test_gather_matmul",
     "TestBlas.test_gather_matmul_grad",
     # Scan NYI
+    "TestArray.test_api",
     "TestAutograd.test_cumprod_grad",
     "TestOps.test_scans",
     "TestOps.test_logcumsumexp",
diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py
index 65bd0e873..bc197b673 100644
--- a/python/tests/mlx_tests.py
+++ b/python/tests/mlx_tests.py
@@ -1,6 +1,10 @@
 # Copyright © 2023 Apple Inc.
 
 import os
+
+# Use regular fp32 precision for tests
+os.environ["MLX_ENABLE_TF32"] = "0"
+
 import platform
 import unittest
 from typing import Any, Callable, List, Tuple, Union
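
Note, not part of the patch: a quick host-side sketch of the principal square-root
branch formula that the new cuComplex Sqrt overload applies (r = |x|,
real = sqrt((r + Re)/2), imag = copysign(sqrt((r - Re)/2), Im), with a special
case at zero), checked against Python's cmath.sqrt. The helper name
sqrt_like_patch is made up for illustration; only the standard-library math and
cmath modules are assumed.

    import cmath
    import math

    def sqrt_like_patch(x: complex) -> complex:
        # Same branch as the CUDA overload: zero maps to zero, otherwise
        # r = |x|, a = sqrt((r + Re)/2), b = copysign(sqrt((r - Re)/2), Im).
        xr, xi = x.real, x.imag
        if xr == 0.0 and xi == 0.0:
            return complex(0.0, 0.0)
        r = abs(x)
        a = math.sqrt((r + xr) / 2.0)
        b = math.copysign(math.sqrt((r - xr) / 2.0), xi)
        return complex(a, b)

    # Agrees with the reference implementation on a few sample points.
    for z in [3 + 4j, -1 + 0.5j, 2 - 7j]:
        assert cmath.isclose(sqrt_like_patch(z), cmath.sqrt(z), rel_tol=1e-12)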