mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Enable tests
This commit is contained in:
@@ -266,7 +266,6 @@ __global__ void strided_scan(
|
||||
for (int i = 0; i < n_scans; ++i) {
|
||||
values[i] = read_from[i];
|
||||
}
|
||||
warp.sync();
|
||||
|
||||
// Perform the scan.
|
||||
for (int i = 0; i < n_scans; ++i) {
|
||||
@@ -436,8 +435,12 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int64_t stride_blocks = cuda::ceil_div(stride, BN);
|
||||
dim3 num_blocks = get_2d_grid_dims(
|
||||
in.shape(), in.strides(), axis_size * stride);
|
||||
num_blocks.x *= stride_blocks;
|
||||
int block_dim = BN / N_READS * WARP_SIZE;
|
||||
if (num_blocks.x * stride_blocks <= UINT32_MAX) {
|
||||
num_blocks.x *= stride_blocks;
|
||||
} else {
|
||||
num_blocks.y *= stride_blocks;
|
||||
}
|
||||
int block_dim = (BN / N_READS) * WARP_SIZE;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
|
||||
@@ -13,11 +13,6 @@ cuda_skip = {
|
||||
"TestBlas.test_gather_mm_sorted",
|
||||
# Segmented matmul NYI
|
||||
"TestBlas.test_segmented_mm",
|
||||
# Scan NYI
|
||||
"TestArray.test_api",
|
||||
"TestAutograd.test_cumprod_grad",
|
||||
"TestOps.test_scans",
|
||||
"TestOps.test_logcumsumexp",
|
||||
# Hadamard NYI
|
||||
"TestOps.test_hadamard",
|
||||
"TestOps.test_hadamard_grad_vmap",
|
||||
|
||||
Reference in New Issue
Block a user