From 5c932c7bb043eaf61868f8d228ba64bbd7dc11e2 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Wed, 9 Jul 2025 01:00:13 +0000
Subject: [PATCH] Use uint as index type

---
 mlx/backend/cuda/binary.cu               | 46 +++++------------
 mlx/backend/cuda/binary_two.cu           | 60 ++++++++----------------
 mlx/backend/cuda/copy/copy_contiguous.cu | 30 ++++--------
 mlx/backend/cuda/ternary.cu              | 13 ++---
 mlx/backend/cuda/unary.cu                | 14 ++----
 5 files changed, 51 insertions(+), 112 deletions(-)

diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu
index 8d683790c..b5db056d3 100644
--- a/mlx/backend/cuda/binary.cu
+++ b/mlx/backend/cuda/binary.cu
@@ -20,15 +20,10 @@ namespace cg = cooperative_groups;
 template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  IdxT remaining = size - index * N_READS;
-  if (remaining <= 0) {
-    return;
-  }
 
-  if (remaining < N_READS) {
-    for (int i = 0; i < remaining; ++i) {
-      IdxT offset = index * N_READS + i;
-      out[offset] = Op{}(a[0], b[0]);
+  if ((index + 1) * N_READS > size) {
+    for (IdxT i = index * N_READS; i < size; ++i) {
+      out[i] = Op{}(a[0], b[0]);
     }
   } else {
     AlignedVector<Out, N_READS> out_vec;
@@ -44,15 +39,10 @@ __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
 template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  IdxT remaining = size - index * N_READS;
-  if (remaining <= 0) {
-    return;
-  }
 
-  if (remaining < N_READS) {
-    for (int i = 0; i < remaining; ++i) {
-      IdxT offset = index * N_READS + i;
-      out[offset] = Op{}(a[0], b[offset]);
+  if ((index + 1) * N_READS > size) {
+    for (IdxT i = index * N_READS; i < size; ++i) {
+      out[i] = Op{}(a[0], b[i]);
     }
   } else {
     auto b_vec = load_vector<N_READS>(b, index);
@@ -70,15 +60,10 @@ __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
 template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  IdxT remaining = size - index * N_READS;
-  if (remaining <= 0) {
-    return;
-  }
 
-  if (remaining < N_READS) {
-    for (int i = 0; i < remaining; ++i) {
-      IdxT offset = index * N_READS + i;
-      out[offset] = Op{}(a[offset], b[0]);
+  if ((index + 1) * N_READS > size) {
+    for (IdxT i = index * N_READS; i < size; ++i) {
+      out[i] = Op{}(a[i], b[0]);
     }
   } else {
     auto a_vec = load_vector<N_READS>(a, index);
@@ -96,15 +81,10 @@ __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
 template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  IdxT remaining = size - index * N_READS;
-  if (remaining <= 0) {
-    return;
-  }
 
-  if (remaining < N_READS) {
-    for (int i = 0; i < remaining; ++i) {
-      IdxT offset = index * N_READS + i;
-      out[offset] = Op{}(a[offset], b[offset]);
+  if ((index + 1) * N_READS > size) {
+    for (IdxT i = index * N_READS; i < size; ++i) {
+      out[i] = Op{}(a[i], b[i]);
     }
   } else {
     auto a_vec = load_vector<N_READS>(a, index);
@@ -268,7 +248,7 @@ void binary_op_gpu_inplace(
       });
    } else {
      dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
-      using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+      using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
       // TODO: Choose optimized value based on type size.
       constexpr int N_READS = 4;
       auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu
index bbebb5661..93fb1da59 100644
--- a/mlx/backend/cuda/binary_two.cu
+++ b/mlx/backend/cuda/binary_two.cu
@@ -21,17 +21,12 @@
 template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void
 binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  IdxT remaining = size - index * N_READS;
-  if (remaining <= 0) {
-    return;
-  }
-  if (remaining < N_READS) {
-    for (int i = 0; i < remaining; ++i) {
-      IdxT offset = index * N_READS + i;
+  if ((index + 1) * N_READS > size) {
+    for (IdxT i = index * N_READS; i < size; ++i) {
       auto out = Op{}(a[0], b[0]);
-      out_a[offset] = out[0];
-      out_b[offset] = out[1];
+      out_a[i] = out[0];
+      out_b[i] = out[1];
     }
   } else {
     AlignedVector<Out, N_READS> out_a_vec;
@@ -52,17 +47,12 @@
 template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void
 binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  IdxT remaining = size - index * N_READS;
-  if (remaining <= 0) {
-    return;
-  }
-  if (remaining < N_READS) {
-    for (int i = 0; i < remaining; ++i) {
-      IdxT offset = index * N_READS + i;
-      auto out = Op{}(a[0], b[offset]);
-      out_a[offset] = out[0];
-      out_b[offset] = out[1];
+  if ((index + 1) * N_READS > size) {
+    for (IdxT i = index * N_READS; i < size; ++i) {
+      auto out = Op{}(a[0], b[i]);
+      out_a[i] = out[0];
+      out_b[i] = out[1];
     }
   } else {
     auto b_vec = load_vector<N_READS>(b, index);
@@ -85,17 +75,12 @@
 template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void
 binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  IdxT remaining = size - index * N_READS;
-  if (remaining <= 0) {
-    return;
-  }
-  if (remaining < N_READS) {
-    for (int i = 0; i < remaining; ++i) {
-      IdxT offset = index * N_READS + i;
-      auto out = Op{}(a[offset], b[0]);
-      out_a[offset] = out[0];
-      out_b[offset] = out[1];
+  if ((index + 1) * N_READS > size) {
+    for (IdxT i = index * N_READS; i < size; ++i) {
+      auto out = Op{}(a[i], b[0]);
+      out_a[i] = out[0];
+      out_b[i] = out[1];
     }
   } else {
     auto a_vec = load_vector<N_READS>(a, index);
@@ -118,17 +103,12 @@
 template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void
 binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  IdxT remaining = size - index * N_READS;
-  if (remaining <= 0) {
-    return;
-  }
-  if (remaining < N_READS) {
-    for (int i = 0; i < remaining; ++i) {
-      IdxT offset = index * N_READS + i;
-      auto out = Op{}(a[offset], b[offset]);
-      out_a[offset] = out[0];
-      out_b[offset] = out[1];
+  if ((index + 1) * N_READS > size) {
+    for (IdxT i = index * N_READS; i < size; ++i) {
+      auto out = Op{}(a[i], b[i]);
+      out_a[i] = out[0];
+      out_b[i] = out[1];
     }
   } else {
     auto a_vec = load_vector<N_READS>(a, index);
@@ -290,7 +270,7 @@ void binary_two_op_gpu_inplace(
      });
    } else {
      dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
-      using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+      using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
       // TODO: Choose optimized value based on type size.
       constexpr int N_READS = 4;
       auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu
index e4feed91f..4e9eaccb7 100644
--- a/mlx/backend/cuda/copy/copy_contiguous.cu
+++ b/mlx/backend/cuda/copy/copy_contiguous.cu
@@ -13,21 +13,16 @@ namespace cg = cooperative_groups;
 template <typename In, typename Out, typename IdxT, int N_READS>
 __global__ void copy_s(const In* in, Out* out, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  IdxT remaining = size - index * N_READS;
-  if (remaining <= 0) {
-    return;
-  }
 
-  if (remaining < N_READS) {
-    for (int i = 0; i < remaining; ++i) {
-      IdxT offset = index * N_READS + i;
-      out[offset] = CastOp<In, Out>{}(in[0]);
+  if ((index + 1) * N_READS > size) {
+    for (IdxT i = index * N_READS; i < size; ++i) {
+      out[i] = cast_to<Out>(in[0]);
     }
   } else {
     AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
     for (int i = 0; i < N_READS; ++i) {
-      out_vec.val[i] = CastOp<In, Out>{}(in[0]);
+      out_vec.val[i] = cast_to<Out>(in[0]);
     }
 
     store_vector<N_READS>(out, index, out_vec);
@@ -37,15 +32,10 @@ __global__ void copy_s(const In* in, Out* out, IdxT size) {
 template <typename In, typename Out, typename IdxT, int N_READS>
 __global__ void copy_v(const In* in, Out* out, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  IdxT remaining = size - index * N_READS;
-  if (remaining <= 0) {
-    return;
-  }
 
-  if (remaining < N_READS) {
-    for (int i = 0; i < remaining; ++i) {
-      IdxT offset = index * N_READS + i;
-      out[offset] = CastOp<In, Out>{}(in[offset]);
+  if ((index + 1) * N_READS > size) {
+    for (IdxT i = index * N_READS; i < size; ++i) {
+      out[i] = cast_to<Out>(in[i]);
     }
   } else {
     auto in_vec = load_vector<N_READS>(in, index);
@@ -53,7 +43,7 @@ __global__ void copy_v(const In* in, Out* out, IdxT size) {
     AlignedVector<Out, N_READS> out_vec;
 #pragma unroll
     for (int i = 0; i < N_READS; ++i) {
-      out_vec.val[i] = CastOp<In, Out>{}(in_vec.val[i]);
+      out_vec.val[i] = cast_to<Out>(in_vec.val[i]);
     }
 
     store_vector<N_READS>(out, index, out_vec);
@@ -71,10 +61,10 @@ void copy_contiguous(
     int64_t out_offset) {
   dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
     dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
-      dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
+      dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
        using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
        using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
-        using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+        using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
        // TODO: Choose optimized value based on type size.
        constexpr int N_READS = 4;
        auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu
index db3b46a78..2122ba497 100644
--- a/mlx/backend/cuda/ternary.cu
+++ b/mlx/backend/cuda/ternary.cu
@@ -19,15 +19,10 @@
 template <typename Op, typename T, typename IdxT, int N_READS>
 __global__ void
 ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  IdxT remaining = size - index * N_READS;
-  if (remaining <= 0) {
-    return;
-  }
-  if (remaining < N_READS) {
-    for (int i = 0; i < remaining; ++i) {
-      IdxT offset = index * N_READS + i;
-      out[offset] = Op{}(a[offset], b[offset], c[offset]);
+  if ((index + 1) * N_READS > size) {
+    for (IdxT i = index * N_READS; i < size; ++i) {
+      out[i] = Op{}(a[i], b[i], c[i]);
     }
   } else {
     auto a_vec = load_vector<N_READS>(a, index);
@@ -170,7 +165,7 @@ void ternary_op_gpu_inplace(
      });
    } else {
      dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
-      using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+      using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
       // TODO: Choose optimized value based on type size.
       constexpr int N_READS = 4;
       auto kernel = cu::ternary_v<Op, DType, IdxT, N_READS>;
diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu
index 4f6de45b3..6a5fcce11 100644
--- a/mlx/backend/cuda/unary.cu
+++ b/mlx/backend/cuda/unary.cu
@@ -21,15 +21,10 @@ namespace cg = cooperative_groups;
 template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void unary_v(const In* in, Out* out, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  IdxT remaining = size - index * N_READS;
-  if (remaining <= 0) {
-    return;
-  }
 
-  if (remaining < N_READS) {
-    for (int i = 0; i < remaining; ++i) {
-      IdxT offset = index * N_READS + i;
-      out[offset] = Op{}(in[offset]);
+  if ((index + 1) * N_READS > size) {
+    for (IdxT i = index * N_READS; i < size; ++i) {
+      out[i] = Op{}(in[i]);
     }
   } else {
     auto in_vec = load_vector<N_READS>(in, index);
@@ -130,10 +125,9 @@ void unary_op_gpu_inplace(
      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
      if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
        dispatch_bool(large, [&](auto large) {
-          using IdxT = std::conditional_t<large(), int64_t, int32_t>;
          using InType = cuda_type_t<CTYPE_IN>;
          using OutType = cuda_type_t<CTYPE_OUT>;
-          using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+          using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
          if (contig) {
            // TODO: Choose optimized value based on type size.
            constexpr int N_READS = 4;
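
Note (review addition, not part of the patch): the rewritten bounds check is what makes the unsigned 32-bit IdxT safe. The old guard computed IdxT remaining = size - index * N_READS and early-returned on remaining <= 0, which only works if the subtraction can go negative; with uint32_t it would wrap to a huge positive value and every out-of-range thread would run the tail loop with garbage bounds. The new form (index + 1) * N_READS > size compares only non-negative quantities, and threads past the end fall into the tail loop with a start index already >= size, so the loop body never executes. Below is a self-contained sketch of the same pattern; scale_v, the doubling op, and the launch harness are illustrative stand-ins, not code from this patch:

// Minimal sketch of the chunked-index pattern above, with IdxT = uint32_t
// as on the patch's 32-bit dispatch path. Build with: nvcc sketch.cu
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

template <typename IdxT, int N_READS>
__global__ void scale_v(const float* in, float* out, IdxT size) {
  IdxT index = blockIdx.x * blockDim.x + threadIdx.x;

  if ((index + 1) * N_READS > size) {
    // Partial (or empty) tail chunk: scalar loop. For threads past the
    // end, index * N_READS >= size already, so the body never runs; no
    // signed "remaining" early return is needed.
    for (IdxT i = index * N_READS; i < size; ++i) {
      out[i] = 2.0f * in[i];
    }
  } else {
    // Full chunk: fixed trip count the compiler can unroll (stand-in for
    // the AlignedVector load/store_vector path in the patch).
    IdxT base = index * N_READS;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      out[base + i] = 2.0f * in[base + i];
    }
  }
}

int main() {
  constexpr int N_READS = 4;
  uint32_t size = 1003; // deliberately not a multiple of N_READS
  float *in, *out;
  cudaMallocManaged(&in, size * sizeof(float));
  cudaMallocManaged(&out, size * sizeof(float));
  for (uint32_t i = 0; i < size; ++i) in[i] = float(i);

  // One thread per N_READS elements, rounded up.
  uint32_t threads = (size + N_READS - 1) / N_READS;
  uint32_t block = 256;
  uint32_t grid = (threads + block - 1) / block;
  scale_v<uint32_t, N_READS><<<grid, block>>>(in, out, size);
  cudaDeviceSynchronize();

  printf("out[0]=%g out[%u]=%g\n", out[0], size - 1, out[size - 1]);
  cudaFree(in);
  cudaFree(out);
  return 0;
}

Two smaller points visible in the diff: copy_contiguous.cu also moves its dispatch threshold from INT32_MAX to UINT32_MAX, since an unsigned 32-bit index covers buffers up to 4G elements rather than 2G (binary.cu, binary_two.cu, and ternary.cu keep INT32_MAX in this patch). One boundary case worth noting: the new guard assumes (index + 1) * N_READS cannot itself wrap IdxT, i.e. size stays at least N_READS - 1 below the type's maximum.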