mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Enable tests
This commit is contained in:
@@ -266,7 +266,6 @@ __global__ void strided_scan(
|
|||||||
for (int i = 0; i < n_scans; ++i) {
|
for (int i = 0; i < n_scans; ++i) {
|
||||||
values[i] = read_from[i];
|
values[i] = read_from[i];
|
||||||
}
|
}
|
||||||
warp.sync();
|
|
||||||
|
|
||||||
// Perform the scan.
|
// Perform the scan.
|
||||||
for (int i = 0; i < n_scans; ++i) {
|
for (int i = 0; i < n_scans; ++i) {
|
||||||
@@ -436,8 +435,12 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
int64_t stride_blocks = cuda::ceil_div(stride, BN);
|
int64_t stride_blocks = cuda::ceil_div(stride, BN);
|
||||||
dim3 num_blocks = get_2d_grid_dims(
|
dim3 num_blocks = get_2d_grid_dims(
|
||||||
in.shape(), in.strides(), axis_size * stride);
|
in.shape(), in.strides(), axis_size * stride);
|
||||||
|
if (num_blocks.x * stride_blocks <= UINT32_MAX) {
|
||||||
num_blocks.x *= stride_blocks;
|
num_blocks.x *= stride_blocks;
|
||||||
int block_dim = BN / N_READS * WARP_SIZE;
|
} else {
|
||||||
|
num_blocks.y *= stride_blocks;
|
||||||
|
}
|
||||||
|
int block_dim = (BN / N_READS) * WARP_SIZE;
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
|
|||||||
@@ -13,11 +13,6 @@ cuda_skip = {
|
|||||||
"TestBlas.test_gather_mm_sorted",
|
"TestBlas.test_gather_mm_sorted",
|
||||||
# Segmented matmul NYI
|
# Segmented matmul NYI
|
||||||
"TestBlas.test_segmented_mm",
|
"TestBlas.test_segmented_mm",
|
||||||
# Scan NYI
|
|
||||||
"TestArray.test_api",
|
|
||||||
"TestAutograd.test_cumprod_grad",
|
|
||||||
"TestOps.test_scans",
|
|
||||||
"TestOps.test_logcumsumexp",
|
|
||||||
# Hadamard NYI
|
# Hadamard NYI
|
||||||
"TestOps.test_hadamard",
|
"TestOps.test_hadamard",
|
||||||
"TestOps.test_hadamard_grad_vmap",
|
"TestOps.test_hadamard_grad_vmap",
|
||||||
|
|||||||
Reference in New Issue
Block a user