From f797b1b3e5ed6b1fc6561f2fd927aef81a8f7f05 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 9 Jul 2025 11:15:10 +0000 Subject: [PATCH] Enable tests --- mlx/backend/cuda/scan.cu | 9 ++++++--- python/tests/cuda_skip.py | 5 ----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu index 198bf2dbb..7a26ee161 100644 --- a/mlx/backend/cuda/scan.cu +++ b/mlx/backend/cuda/scan.cu @@ -266,7 +266,6 @@ __global__ void strided_scan( for (int i = 0; i < n_scans; ++i) { values[i] = read_from[i]; } - warp.sync(); // Perform the scan. for (int i = 0; i < n_scans; ++i) { @@ -436,8 +435,12 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { int64_t stride_blocks = cuda::ceil_div(stride, BN); dim3 num_blocks = get_2d_grid_dims( in.shape(), in.strides(), axis_size * stride); - num_blocks.x *= stride_blocks; - int block_dim = BN / N_READS * WARP_SIZE; + if (num_blocks.x * stride_blocks <= UINT32_MAX) { + num_blocks.x *= stride_blocks; + } else { + num_blocks.y *= stride_blocks; + } + int block_dim = (BN / N_READS) * WARP_SIZE; encoder.add_kernel_node( kernel, num_blocks, diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index afd48bd03..005c612ff 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -13,11 +13,6 @@ cuda_skip = { "TestBlas.test_gather_mm_sorted", # Segmented matmul NYI "TestBlas.test_segmented_mm", - # Scan NYI - "TestArray.test_api", - "TestAutograd.test_cumprod_grad", - "TestOps.test_scans", - "TestOps.test_logcumsumexp", # Hadamard NYI "TestOps.test_hadamard", "TestOps.test_hadamard_grad_vmap",