diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu
index 9582b0378..9776278ec 100644
--- a/mlx/backend/cuda/binary_two.cu
+++ b/mlx/backend/cuda/binary_two.cu
@@ -17,47 +17,134 @@ namespace cu {
 
 namespace cg = cooperative_groups;
 
-template <typename Op, typename In, typename Out, typename IdxT>
+template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void
 binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    auto out = Op{}(a[0], b[0]);
-    out_a[0] = out[0];
-    out_b[0] = out[1];
+  int remaining = size - index * N_READS;
+  if (remaining <= 0) {
+    return;
+  }
+
+  if (remaining < N_READS) {
+    for (int i = 0; i < remaining; ++i) {
+      IdxT offset = index * N_READS + i;
+      auto out = Op{}(a[0], b[0]);
+      out_a[offset] = out[0];
+      out_b[offset] = out[1];
+    }
+  } else {
+    AlignedVector<Out, N_READS> out_a_vec;
+    AlignedVector<Out, N_READS> out_b_vec;
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      auto out = Op{}(a[0], b[0]);
+      out_a_vec.val[i] = out[0];
+      out_b_vec.val[i] = out[1];
+    }
+
+    store_vector<N_READS>(out_a, index, out_a_vec);
+    store_vector<N_READS>(out_b, index, out_b_vec);
   }
 }
 
-template <typename Op, typename In, typename Out, typename IdxT>
+template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void
 binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    auto out = Op{}(a[0], b[index]);
-    out_a[index] = out[0];
-    out_b[index] = out[1];
+  int remaining = size - index * N_READS;
+  if (remaining <= 0) {
+    return;
+  }
+
+  if (remaining < N_READS) {
+    for (int i = 0; i < remaining; ++i) {
+      IdxT offset = index * N_READS + i;
+      auto out = Op{}(a[0], b[offset]);
+      out_a[offset] = out[0];
+      out_b[offset] = out[1];
+    }
+  } else {
+    auto b_vec = load_vector<N_READS>(b, index);
+
+    AlignedVector<Out, N_READS> out_a_vec;
+    AlignedVector<Out, N_READS> out_b_vec;
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      auto out = Op{}(a[0], b_vec.val[i]);
+      out_a_vec.val[i] = out[0];
+      out_b_vec.val[i] = out[1];
+    }
+
+    store_vector<N_READS>(out_a, index, out_a_vec);
+    store_vector<N_READS>(out_b, index, out_b_vec);
   }
 }
 
-template <typename Op, typename In, typename Out, typename IdxT>
+template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void
 binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    auto out = Op{}(a[index], b[0]);
-    out_a[index] = out[0];
-    out_b[index] = out[1];
+  int remaining = size - index * N_READS;
+  if (remaining <= 0) {
+    return;
+  }
+
+  if (remaining < N_READS) {
+    for (int i = 0; i < remaining; ++i) {
+      IdxT offset = index * N_READS + i;
+      auto out = Op{}(a[offset], b[0]);
+      out_a[offset] = out[0];
+      out_b[offset] = out[1];
+    }
+  } else {
+    auto a_vec = load_vector<N_READS>(a, index);
+
+    AlignedVector<Out, N_READS> out_a_vec;
+    AlignedVector<Out, N_READS> out_b_vec;
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      auto out = Op{}(a_vec.val[i], b[0]);
+      out_a_vec.val[i] = out[0];
+      out_b_vec.val[i] = out[1];
+    }
+
+    store_vector<N_READS>(out_a, index, out_a_vec);
+    store_vector<N_READS>(out_b, index, out_b_vec);
   }
 }
 
-template <typename Op, typename In, typename Out, typename IdxT>
+template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void
 binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    auto out = Op{}(a[index], b[index]);
-    out_a[index] = out[0];
-    out_b[index] = out[1];
+  int remaining = size - index * N_READS;
+  if (remaining <= 0) {
+    return;
+  }
+
+  if (remaining < N_READS) {
+    for (int i = 0; i < remaining; ++i) {
+      IdxT offset = index * N_READS + i;
+      auto out = Op{}(a[offset], b[offset]);
+      out_a[offset] = out[0];
+      out_b[offset] = out[1];
+    }
+  } else {
+    auto a_vec = load_vector<N_READS>(a, index);
+    auto b_vec = load_vector<N_READS>(b, index);
+
+    AlignedVector<Out, N_READS> out_a_vec;
+    AlignedVector<Out, N_READS> out_b_vec;
+#pragma unroll
+    for (int i = 0; i < N_READS; ++i) {
+      auto out = Op{}(a_vec.val[i], b_vec.val[i]);
+      out_a_vec.val[i] = out[0];
+      out_b_vec.val[i] = out[1];
+    }
+
+    store_vector<N_READS>(out_a, index, out_a_vec);
+    store_vector<N_READS>(out_b, index, out_b_vec);
   }
 }
 
@@ -200,20 +287,23 @@ void binary_op_gpu_inplace(
   } else {
     dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
       using IdxT = std::conditional_t<large(), int64_t, int32_t>;
-      auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
+      // TODO: Choose optimized value based on type size.
+      constexpr int N_READS = 4;
+      auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
       if (bopt == BinaryOpType::ScalarVector) {
-        kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
+        kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
       } else if (bopt == BinaryOpType::VectorScalar) {
-        kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
+        kernel = cu::binary_vs<Op, InType, OutType, IdxT, N_READS>;
       } else if (bopt == BinaryOpType::VectorVector) {
-        kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
+        kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
       }
       auto [num_blocks, block_dims] = get_launch_args(
           kernel,
           out_a.data_size(),
           out_a.shape(),
           out_a.strides(),
-          large());
+          large(),
+          N_READS);
       encoder.add_kernel_node(
           kernel,
           num_blocks,
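
The kernels in this diff lean on the AlignedVector, load_vector, and store_vector device helpers from MLX's CUDA backend, whose definitions are not part of the patch. Below is a minimal sketch of the shape those helpers are assumed to take, inferred only from the call sites above (the .val[i] member, load_vector<N_READS>(ptr, index), store_vector<N_READS>(ptr, index, vec)): a small POD wrapper whose alignment lets the compiler turn N_READS element accesses into a single wide load or store. The real MLX definitions may differ in detail.

// Sketch only: member name `val` and call shapes come from the diff above;
// everything else here is an illustrative assumption.
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVector {
  T val[N];
};

// Load elements [index * N, index * N + N) as one aligned vector access.
// Assumes `ptr` is aligned to sizeof(T) * N; ragged or unaligned work must be
// routed to a scalar path by the caller.
template <int N, typename T, typename IdxT>
__device__ inline AlignedVector<T, N> load_vector(const T* ptr, IdxT index) {
  auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
  return from[index];
}

// Store elements [index * N, index * N + N) as one aligned vector access.
template <int N, typename T, typename IdxT>
__device__ inline void store_vector(
    T* ptr,
    IdxT index,
    const AlignedVector<T, N>& vec) {
  auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
  to[index] = vec;
}

This is also why each kernel keeps a scalar fallback for the ragged tail: the `remaining < N_READS` branch handles the last partial group element by element, while full groups go through the vectorized loads and stores.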
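The extra N_READS argument passed to get_launch_args reflects that each thread now covers N_READS contiguous elements, so the grid only needs to span roughly size / N_READS threads instead of size. get_launch_args itself is not shown in this diff; the sketch below uses a hypothetical blocks_for helper purely to illustrate the grid arithmetic that the added argument implies.

#include <cuda_runtime.h>
#include <cstddef>

// Hypothetical helper: number of blocks needed when each thread processes
// `work_per_thread` contiguous elements (not MLX's actual get_launch_args).
inline dim3 blocks_for(size_t size, int block_dim, int work_per_thread) {
  // Threads needed to cover all elements, N_READS elements per thread.
  size_t threads = (size + work_per_thread - 1) / work_per_thread;
  // Blocks needed to cover those threads.
  size_t blocks = (threads + block_dim - 1) / block_dim;
  return dim3(static_cast<unsigned int>(blocks), 1, 1);
}

// Example: 1,000,000 elements with 256 threads per block and N_READS = 4
// -> 250,000 threads -> 977 blocks, versus 3,907 blocks without vectorization.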