diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu
index d4df06f18..be8fca8d4 100644
--- a/mlx/backend/cuda/binary.cu
+++ b/mlx/backend/cuda/binary.cu
@@ -150,10 +150,10 @@ void binary_op_gpu_inplace(
   auto [shape, strides] = collapse_contiguous_dims(a, b, out);
   auto& a_strides = strides[0];
   auto& b_strides = strides[1];
-  bool large = a.data_size() > UINT32_MAX ||
-      b.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
+  bool large = a.data_size() > INT32_MAX ||
+      b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
   MLX_SWITCH_BOOL(large, LARGE, {
-    using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
+    using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
     int ndim = shape.size();
     if (ndim <= 3) {
       MLX_SWITCH_1_2_3(ndim, NDIM, {
diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp
index a6b8223e0..1aa7ecb92 100644
--- a/mlx/backend/cuda/compiled.cpp
+++ b/mlx/backend/cuda/compiled.cpp
@@ -130,11 +130,13 @@ struct FusedKernelBuilder {
 constexpr const char* g_jit_includes = R"(
 #include "mlx/backend/cuda/device/binary_ops.cuh"
+#include "mlx/backend/cuda/device/ternary_ops.cuh"
 #include "mlx/backend/cuda/device/unary_ops.cuh"
 #include "mlx/backend/cuda/device/utils.cuh"
 
 #include <cooperative_groups.h>
 
+#define inf cuda::std::numeric_limits<float>::infinity()
 )";
 
 void Compiled::eval_gpu(
diff --git a/mlx/backend/cuda/device/ternary_ops.cuh b/mlx/backend/cuda/device/ternary_ops.cuh
index d1d008ac5..441845471 100644
--- a/mlx/backend/cuda/device/ternary_ops.cuh
+++ b/mlx/backend/cuda/device/ternary_ops.cuh
@@ -1,4 +1,5 @@
 // Copyright © 2025 Apple Inc.
 
+#pragma once
 
 namespace mlx::core::cu {
diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh
index 6f9851c94..d2897203f 100644
--- a/mlx/backend/cuda/device/utils.cuh
+++ b/mlx/backend/cuda/device/utils.cuh
@@ -336,4 +336,21 @@ struct LoopedElemToLoc<1, false, OffsetT> {
   }
 };
 
+inline __device__ cuComplex log1p(cuComplex in) {
+  float x = cuCrealf(in);
+  float y = cuCimagf(in);
+  float zabs = sqrt(x * x + y * y);
+  float theta = atan2f(y, x + 1);
+  if (zabs < 0.5f) {
+    float r = x * (2 + x) + y * y;
+    if (r == 0) { // handle underflow
+      return {x, theta};
+    }
+    return {0.5f * log1pf(r), theta};
+  } else {
+    auto z0 = sqrt((x + 1) * (x + 1) + y * y);
+    return {log(z0), theta};
+  }
+}
+
 } // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu
index 02e46afc1..e33af3c80 100644
--- a/mlx/backend/cuda/ternary.cu
+++ b/mlx/backend/cuda/ternary.cu
@@ -101,10 +101,10 @@ void ternary_op_gpu_inplace(
   auto& a_strides = strides[0];
   auto& b_strides = strides[1];
   auto& c_strides = strides[2];
-  bool large = a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
-      c.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
+  bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
+      c.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
   MLX_SWITCH_BOOL(large, LARGE, {
-    using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
+    using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
     int ndim = shape.size();
     if (ndim <= 3) {
       MLX_SWITCH_1_2_3(ndim, NDIM, {
diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu
index d2fa96381..83e206ff9 100644
--- a/mlx/backend/cuda/unary.cu
+++ b/mlx/backend/cuda/unary.cu
@@ -27,13 +27,13 @@ constexpr bool supports_unary_op() {
       std::is_same_v || std::is_same_v ||
      std::is_same_v || std::is_same_v ||
       std::is_same_v || std::is_same_v ||
-      std::is_same_v || std::is_same_v ||
+      std::is_same_v || std::is_same_v ||
       std::is_same_v || std::is_same_v) {
     return std::is_same_v && is_floating_v;
   }
   if (std::is_same_v || std::is_same_v ||
-      std::is_same_v) {
+      std::is_same_v || std::is_same_v) {
     return std::is_same_v && is_inexact_v;
   }
   if (std::is_same_v) {
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index 0072db192..d3f3e4bda 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -1,24 +1,68 @@
 cuda_skip = {
     "TestArray.test_api",
-    "TestAutograd.test_cumprod_grad",
     "TestAutograd.test_slice_grads",
     "TestAutograd.test_split_against_slice",
     "TestAutograd.test_stop_gradient",
-    "TestAutograd.test_topk_grad",
     "TestAutograd.test_update_state",
     "TestAutograd.test_vjp",
     "TestBF16.test_arg_reduction_ops",
     "TestBF16.test_binary_ops",
     "TestBF16.test_reduction_ops",
-    "TestBlas.test_block_masked_matmul",
     "TestBlas.test_complex_gemm",
-    "TestBlas.test_gather_matmul",
-    "TestBlas.test_gather_matmul_grad",
     "TestBlas.test_matmul_batched",
     "TestBlas.test_matrix_vector_attn",
     "TestCompile.test_compile_dynamic_dims",
-    "TestCompile.test_compile_inf",
-    "TestCompile.test_inf_constant",
+    "TestEinsum.test_attention",
+    "TestEinsum.test_ellipses",
+    "TestEinsum.test_opt_einsum_test_cases",
+    "TestEval.test_multi_output_eval_during_transform",
+    "TestLoad.test_load_f8_e4m3",
+    "TestLosses.test_binary_cross_entropy",
+    "TestMemory.test_memory_info",
+    "TestLayers.test_group_norm",
+    "TestLayers.test_pooling",
+    "TestLayers.test_quantized_embedding",
+    "TestLayers.test_sin_pe",
+    "TestLayers.test_upsample",
+    "TestOps.test_array_equal",
+    "TestOps.test_complex_ops",
+    "TestOps.test_divmod",
+    "TestOps.test_dynamic_slicing",
+    "TestOps.test_irregular_binary_ops",
+    "TestOps.test_kron",
+    "TestOps.test_logaddexp",
+    "TestOps.test_softmax",
+    "TestOps.test_sort",
+    "TestOps.test_tensordot",
+    "TestOps.test_tile",
+    "TestReduce.test_axis_permutation_sums",
+    "TestReduce.test_dtypes",
+    "TestReduce.test_expand_sums",
+    "TestReduce.test_many_reduction_axes",
+    "TestUpsample.test_torch_upsample",
+
+    # Partition NYI
+    "TestAutograd.test_topk_grad",
+    "TestOps.test_argpartition",
+    "TestOps.test_partition",
+
+    # Block masked matmul NYI
+    "TestBlas.test_block_masked_matmul",
+
+    # Gather matmul NYI
+    "TestBlas.test_gather_matmul",
+    "TestBlas.test_gather_matmul_grad",
+
+    # Scan NYI
+    "TestAutograd.test_cumprod_grad",
+    "TestOps.test_scans",
+    "TestOps.test_logcumsumexp",
+
+    # Hadamard NYI
+    "TestOps.test_hadamard",
+    "TestOps.test_hadamard_grad_vmap",
+
+    # Convolutions NYI
     "TestConv.test_1d_conv_with_2d",
     "TestConv.test_asymmetric_padding",
     "TestConv.test_basic_grad_shapes",
@@ -45,11 +89,12 @@ cuda_skip = {
     "TestConvTranspose.test_torch_conv_transpose_3D",
     "TestConvTranspose.test_torch_conv_transpose_3D_grad",
     "TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
-    "TestEinsum.test_attention",
-    "TestEinsum.test_ellipses",
-    "TestEinsum.test_opt_einsum_test_cases",
-    "TestEval.test_multi_output_eval_during_transform",
     "TestExportImport.test_export_conv",
+    "TestLayers.test_conv1d",
+    "TestLayers.test_conv2d",
+    "TestVmap.test_vmap_conv",
+
+    # FFTs NYI
     "TestFFT.test_fft",
     "TestFFT.test_fft_big_powers_of_two",
     "TestFFT.test_fft_contiguity",
@@ -59,52 +104,24 @@ cuda_skip = {
     "TestFFT.test_fft_large_numbers",
     "TestFFT.test_fft_shared_mem",
     "TestFFT.test_fftn",
-    "TestInit.test_orthogonal",
+
+    # Lapack ops NYI
     "TestLinalg.test_cholesky",
     "TestLinalg.test_cholesky_inv",
     "TestLinalg.test_eig",
     "TestLinalg.test_eigh",
     "TestLinalg.test_inverse",
+    "TestVmap.test_vmap_inverse",
     "TestLinalg.test_lu",
     "TestLinalg.test_lu_factor",
     "TestLinalg.test_pseudo_inverse",
     "TestLinalg.test_qr_factorization",
+    "TestInit.test_orthogonal",
     "TestLinalg.test_svd_decomposition",
+    "TestVmap.test_vmap_svd",
     "TestLinalg.test_tri_inverse",
-    "TestLoad.test_load_f8_e4m3",
-    "TestLosses.test_binary_cross_entropy",
-    "TestMemory.test_memory_info",
-    "TestLayers.test_conv1d",
-    "TestLayers.test_conv2d",
-    "TestLayers.test_elu",
-    "TestLayers.test_group_norm",
-    "TestLayers.test_hard_shrink",
-    "TestLayers.test_pooling",
-    "TestLayers.test_quantized_embedding",
-    "TestLayers.test_sin_pe",
-    "TestLayers.test_softshrink",
-    "TestLayers.test_upsample",
-    "TestOps.test_argpartition",
-    "TestOps.test_array_equal",
-    "TestOps.test_as_strided",
-    "TestOps.test_binary_ops",
-    "TestOps.test_bitwise_grad",
-    "TestOps.test_complex_ops",
-    "TestOps.test_divmod",
-    "TestOps.test_dynamic_slicing",
-    "TestOps.test_hadamard",
-    "TestOps.test_hadamard_grad_vmap",
-    "TestOps.test_irregular_binary_ops",
-    "TestOps.test_kron",
-    "TestOps.test_log1p",
-    "TestOps.test_logaddexp",
-    "TestOps.test_logcumsumexp",
-    "TestOps.test_partition",
-    "TestOps.test_scans",
-    "TestOps.test_softmax",
-    "TestOps.test_sort",
-    "TestOps.test_tensordot",
-    "TestOps.test_tile",
+
+    # Quantization NYI
     "TestQuantized.test_gather_matmul_grad",
     "TestQuantized.test_gather_qmm",
     "TestQuantized.test_gather_qmm_sorted",
@@ -120,12 +137,4 @@ cuda_skip = {
     "TestQuantized.test_small_matrix",
     "TestQuantized.test_throw",
     "TestQuantized.test_vjp_scales_biases",
-    "TestReduce.test_axis_permutation_sums",
-    "TestReduce.test_dtypes",
-    "TestReduce.test_expand_sums",
-    "TestReduce.test_many_reduction_axes",
-    "TestUpsample.test_torch_upsample",
-    "TestVmap.test_vmap_conv",
-    "TestVmap.test_vmap_inverse",
-    "TestVmap.test_vmap_svd",
 }