Enable tests

This commit is contained in:
Cheng
2025-07-09 11:15:10 +00:00
parent b89d8ef1c0
commit f797b1b3e5
2 changed files with 6 additions and 8 deletions

View File

@@ -266,7 +266,6 @@ __global__ void strided_scan(
for (int i = 0; i < n_scans; ++i) { for (int i = 0; i < n_scans; ++i) {
values[i] = read_from[i]; values[i] = read_from[i];
} }
warp.sync();
// Perform the scan. // Perform the scan.
for (int i = 0; i < n_scans; ++i) { for (int i = 0; i < n_scans; ++i) {
@@ -436,8 +435,12 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
int64_t stride_blocks = cuda::ceil_div(stride, BN); int64_t stride_blocks = cuda::ceil_div(stride, BN);
dim3 num_blocks = get_2d_grid_dims( dim3 num_blocks = get_2d_grid_dims(
in.shape(), in.strides(), axis_size * stride); in.shape(), in.strides(), axis_size * stride);
num_blocks.x *= stride_blocks; if (num_blocks.x * stride_blocks <= UINT32_MAX) {
int block_dim = BN / N_READS * WARP_SIZE; num_blocks.x *= stride_blocks;
} else {
num_blocks.y *= stride_blocks;
}
int block_dim = (BN / N_READS) * WARP_SIZE;
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
num_blocks, num_blocks,

View File

@@ -13,11 +13,6 @@ cuda_skip = {
"TestBlas.test_gather_mm_sorted", "TestBlas.test_gather_mm_sorted",
# Segmented matmul NYI # Segmented matmul NYI
"TestBlas.test_segmented_mm", "TestBlas.test_segmented_mm",
# Scan NYI
"TestArray.test_api",
"TestAutograd.test_cumprod_grad",
"TestOps.test_scans",
"TestOps.test_logcumsumexp",
# Hadamard NYI # Hadamard NYI
"TestOps.test_hadamard", "TestOps.test_hadamard",
"TestOps.test_hadamard_grad_vmap", "TestOps.test_hadamard_grad_vmap",