Awni Hannun 2025-06-16 12:35:26 -07:00
parent ff1f9ca5e8
commit abcf62ee55
3 changed files with 7 additions and 16 deletions

View File

@@ -1,7 +1,7 @@
 // Copyright © 2025 Apple Inc.
-#include "mlx/backend/cuda/device/fp16_math.cuh"
 #include "mlx/backend/cuda/device/cucomplex_math.cuh"
+#include "mlx/backend/cuda/device/fp16_math.cuh"
 #include "mlx/backend/cuda/device/utils.cuh"
 #include <cuComplex.h>
@@ -124,11 +124,13 @@ struct LogAddExp {
         ? maxval
         : T(float(maxval) + log1p(expf(minval - maxval)));
   };

   __device__ cuComplex operator()(cuComplex x, cuComplex y) {
     if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
         isnan(cuCimagf(y))) {
-      return {cuda::std::numeric_limits<float>::quiet_NaN(), cuda::std::numeric_limits<float>::quiet_NaN()};
+      return {
+          cuda::std::numeric_limits<float>::quiet_NaN(),
+          cuda::std::numeric_limits<float>::quiet_NaN()};
     }
     constexpr float inf = cuda::std::numeric_limits<float>::infinity();
     auto maxval = x > y ? x : y;
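For context: LogAddExp evaluates log(e^x + e^y) via the numerically stable identity max + log1p(exp(min - max)), and the complex overload short-circuits to NaN when any component is NaN; the change in this hunk is only a line-wrapping of that return. Below is a minimal host-side C++ sketch of the same real-valued trick, under the hypothetical name logaddexp_ref (not MLX API):

#include <algorithm>
#include <cmath>
#include <limits>

// Hypothetical reference implementation (not MLX API) of the stable
// log-add-exp identity used by the kernel above:
//   log(e^x + e^y) = m + log1p(e^(min - m)),  with m = max(x, y).
float logaddexp_ref(float x, float y) {
  // Propagate NaN explicitly, mirroring the kernel's isnan guard.
  if (std::isnan(x) || std::isnan(y)) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  float maxval = std::max(x, y);
  float minval = std::min(x, y);
  // If maxval is +/-inf, (minval - maxval) could be inf - inf = NaN;
  // the exact answer in both cases is maxval itself.
  if (std::isinf(maxval)) {
    return maxval;
  }
  // minval - maxval <= 0, so exp cannot overflow, and log1p stays
  // accurate when exp(minval - maxval) is tiny.
  return maxval + std::log1p(std::exp(minval - maxval));
}

Subtracting the max before exponentiating is what keeps the exponent non-positive, which is why the kernel's expf call cannot overflow.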

View File

@ -27,9 +27,8 @@ constexpr bool supports_unary_op() {
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> || std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Expm1> || std::is_same_v<Op, Sigmoid> ||
std::is_same_v<Op, Sigmoid> || std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
std::is_same_v<Op, Rsqrt>) {
return std::is_same_v<In, Out> && is_floating_v<In>; return std::is_same_v<In, Out> && is_floating_v<In>;
} }
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> || if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
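For context, supports_unary_op() is a constexpr gate mapping each op type to the dtype combinations it supports; the group rewrapped above (Expm1, Sigmoid, Sqrt, Rsqrt) is restricted to floating-point inputs with matching input/output types. A reduced sketch of the pattern, with hypothetical stand-in op tags (Exp, Abs) rather than MLX's real op types:

#include <type_traits>

// Hypothetical stand-in op tags; MLX's real op types live elsewhere.
struct Exp {};
struct Abs {};

// Simplified stand-in for MLX's is_floating_v trait.
template <typename T>
inline constexpr bool is_floating_v =
    std::is_same_v<T, float> || std::is_same_v<T, double>;

// Compile-time gate: does a kernel instantiation exist for this
// (Op, In, Out) combination?
template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() {
  // Transcendental ops only make sense for floating-point, in == out.
  if (std::is_same_v<Op, Exp>) {
    return std::is_same_v<In, Out> && is_floating_v<In>;
  }
  // Abs works for any arithmetic type, in == out.
  if (std::is_same_v<Op, Abs>) {
    return std::is_same_v<In, Out> && std::is_arithmetic_v<In>;
  }
  return false;
}

// Evaluated at compile time, so unsupported combinations are pruned
// before any kernel code is instantiated:
static_assert(supports_unary_op<Exp, float, float>());
static_assert(!supports_unary_op<Exp, int, int>());
static_assert(supports_unary_op<Abs, int, int>());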

View File

@@ -25,32 +25,25 @@ cuda_skip = {
     "TestReduce.test_expand_sums",
     "TestReduce.test_many_reduction_axes",
     "TestUpsample.test_torch_upsample",
     # DivMod NYI
     "TestOps.test_divmod",
     "TestEval.test_multi_output_eval_during_transform",
     # Partition NYI
     "TestAutograd.test_topk_grad",
     "TestOps.test_argpartition",
     "TestOps.test_partition",
     # Block masked matmul NYI
     "TestBlas.test_block_masked_matmul",
     # Gather matmul NYI
     "TestBlas.test_gather_matmul",
     "TestBlas.test_gather_matmul_grad",
     # Scan NYI
     "TestAutograd.test_cumprod_grad",
     "TestOps.test_scans",
     "TestOps.test_logcumsumexp",
     # Hadamard NYI
     "TestOps.test_hadamard",
     "TestOps.test_hadamard_grad_vmap",
     # Convolutions NYI
     "TestConv.test_1d_conv_with_2d",
     "TestConv.test_asymmetric_padding",
@@ -82,7 +75,6 @@ cuda_skip = {
     "TestLayers.test_conv1d",
     "TestLayers.test_conv2d",
     "TestVmap.test_vmap_conv",
     # FFTs NYI
     "TestFFT.test_fft",
     "TestFFT.test_fft_big_powers_of_two",
@@ -93,7 +85,6 @@ cuda_skip = {
     "TestFFT.test_fft_large_numbers",
     "TestFFT.test_fft_shared_mem",
     "TestFFT.test_fftn",
     # Lapack ops NYI
     "TestLinalg.test_cholesky",
     "TestLinalg.test_cholesky_inv",
@@ -109,7 +100,6 @@ cuda_skip = {
     "TestLinalg.test_svd_decomposition",
     "TestVmap.test_vmap_svd",
     "TestLinalg.test_tri_inverse",
     # Quantization NYI
     "TestQuantized.test_gather_matmul_grad",
     "TestQuantized.test_gather_qmm",