Enable tests

This commit is contained in:
Cheng
2025-07-09 11:15:10 +00:00
parent b89d8ef1c0
commit f797b1b3e5
2 changed files with 6 additions and 8 deletions

View File

@@ -266,7 +266,6 @@ __global__ void strided_scan(
for (int i = 0; i < n_scans; ++i) {
values[i] = read_from[i];
}
warp.sync();
// Perform the scan.
for (int i = 0; i < n_scans; ++i) {
@@ -436,8 +435,12 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
int64_t stride_blocks = cuda::ceil_div(stride, BN);
dim3 num_blocks = get_2d_grid_dims(
in.shape(), in.strides(), axis_size * stride);
num_blocks.x *= stride_blocks;
int block_dim = BN / N_READS * WARP_SIZE;
if (num_blocks.x * stride_blocks <= UINT32_MAX) {
num_blocks.x *= stride_blocks;
} else {
num_blocks.y *= stride_blocks;
}
int block_dim = (BN / N_READS) * WARP_SIZE;
encoder.add_kernel_node(
kernel,
num_blocks,

View File

@@ -13,11 +13,6 @@ cuda_skip = {
"TestBlas.test_gather_mm_sorted",
# Segmented matmul NYI
"TestBlas.test_segmented_mm",
# Scan NYI
"TestArray.test_api",
"TestAutograd.test_cumprod_grad",
"TestOps.test_scans",
"TestOps.test_logcumsumexp",
# Hadamard NYI
"TestOps.test_hadamard",
"TestOps.test_hadamard_grad_vmap",