From 5c3663d4a78127cad5d74441fa9289a8566f8d5a Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 8 Jul 2025 01:25:05 +0000 Subject: [PATCH] Fix tests on large arrays --- mlx/backend/cuda/binary.cu | 8 ++++---- mlx/backend/cuda/binary_two.cu | 8 ++++---- mlx/backend/cuda/copy/copy_contiguous.cu | 4 ++-- mlx/backend/cuda/ternary.cu | 2 +- mlx/backend/cuda/unary.cu | 2 +- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index 5f3f29094..8d683790c 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -20,7 +20,7 @@ namespace cg = cooperative_groups; template __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; + IdxT remaining = size - index * N_READS; if (remaining <= 0) { return; } @@ -44,7 +44,7 @@ __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) { template __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; + IdxT remaining = size - index * N_READS; if (remaining <= 0) { return; } @@ -70,7 +70,7 @@ __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) { template __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; + IdxT remaining = size - index * N_READS; if (remaining <= 0) { return; } @@ -96,7 +96,7 @@ __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) { template __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; + IdxT remaining = size - index * N_READS; if (remaining <= 0) { return; } diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index 629758197..bbebb5661 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -21,7 +21,7 @@ template __global__ void binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; + IdxT remaining = size - index * N_READS; if (remaining <= 0) { return; } @@ -52,7 +52,7 @@ template __global__ void binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; + IdxT remaining = size - index * N_READS; if (remaining <= 0) { return; } @@ -85,7 +85,7 @@ template __global__ void binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; + IdxT remaining = size - index * N_READS; if (remaining <= 0) { return; } @@ -118,7 +118,7 @@ template __global__ void binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; + IdxT remaining = size - index * N_READS; if (remaining <= 0) { return; } diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index 2868b13a3..e4feed91f 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -13,7 +13,7 @@ namespace cg = cooperative_groups; template __global__ void copy_s(const In* in, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; + IdxT remaining = size - index * N_READS; if (remaining <= 0) { return; } @@ -37,7 +37,7 @@ __global__ void copy_s(const In* in, Out* out, IdxT size) { template __global__ void copy_v(const In* in, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; + IdxT remaining = size - index * N_READS; if (remaining <= 0) { return; } diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index 696eaa8ce..db3b46a78 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -19,7 +19,7 @@ template __global__ void ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; + IdxT remaining = size - index * N_READS; if (remaining <= 0) { return; } diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index 581bdbe50..4f6de45b3 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -21,7 +21,7 @@ namespace cg = cooperative_groups; template __global__ void unary_v(const In* in, Out* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - int remaining = size - index * N_READS; + IdxT remaining = size - index * N_READS; if (remaining <= 0) { return; }