diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index aa6523f27..d2ba3e6a0 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -15,12 +15,33 @@ namespace cu { namespace cg = cooperative_groups; -template +template __global__ void ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) { IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - out[index] = Op{}(a[index], b[index], c[index]); + int remaining = size - index * N_READS; + if (remaining <= 0) { + return; + } + + if (remaining < N_READS) { + for (int i = 0; i < remaining; ++i) { + IdxT offset = index * N_READS + i; + out[offset] = Op{}(a[offset], b[offset], c[offset]); + } + } else { + auto a_vec = load_vector(a, index); + auto b_vec = load_vector(b, index); + auto c_vec = load_vector(c, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec.val[i] = CastOp{}(a_vec.val[i]); + out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i], c_vec.val[i]); + } + + store_vector(out, index, out_vec); } } @@ -151,9 +172,16 @@ void ternary_op_gpu_inplace( } else { dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { using IdxT = std::conditional_t; - auto kernel = cu::ternary_v; + // TODO: Choose optimized value based on type size. + constexpr int N_READS = 4; + auto kernel = cu::ternary_v; auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), large()); + kernel, + out.data_size(), + out.shape(), + out.strides(), + large(), + N_READS); encoder.add_kernel_node( kernel, num_blocks,