diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu index ba5836a331..2afcc7e705 100644 --- a/mlx/backend/cuda/logsumexp.cu +++ b/mlx/backend/cuda/logsumexp.cu @@ -43,20 +43,19 @@ __global__ void logsumexp(const T* in, T* out, int axis_size) { AccT maxval = Limits::finite_min(); AccT normalizer = 0; for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { - AccT vals[N_READS]; - cub::LoadDirectBlocked( - r * BLOCK_DIM + block.thread_rank(), - make_cast_iterator(in), - vals, - axis_size, - Limits::min()); + auto index = r * BLOCK_DIM + block.thread_rank(); + auto vals = load_vector(in, index, axis_size, Limits::min()); prevmax = maxval; - maxval = max_op(maxval, cub::ThreadReduce(vals, max_op)); +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + maxval = max_op(maxval, static_cast(vals[i])); + } // Online normalizer calculation for softmax: // https://github.com/NVIDIA/online-softmax normalizer = normalizer * softmax_exp(prevmax - maxval); for (int i = 0; i < N_READS; i++) { - normalizer = normalizer + softmax_exp(vals[i] - maxval); + normalizer = + normalizer + softmax_exp(static_cast(vals[i]) - maxval); } } @@ -143,9 +142,9 @@ void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); encoder.set_output_array(out); dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) { - constexpr int N_READS = 4; + using DataType = cuda_type_t; + constexpr int N_READS = 16 / sizeof(DataType); dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; auto kernel = cu::logsumexp; encoder.add_kernel_node( kernel, diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu index 53615ae4de..2ff4464b01 100644 --- a/mlx/backend/cuda/softmax.cu +++ b/mlx/backend/cuda/softmax.cu @@ -11,7 +11,6 @@ #include #include #include -#include #include @@ -45,20 +44,21 @@ __global__ void softmax(const T* in, T* out, int axis_size) { AccT maxval = Limits::finite_min(); AccT normalizer = cast_to(0); for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { - AccT vals[N_READS]; - cub::LoadDirectBlocked( - r * BLOCK_DIM + block.thread_rank(), - make_cast_iterator(in), - vals, - axis_size, - Limits::min()); + auto index = r * BLOCK_DIM + block.thread_rank(); + auto vals = load_vector(in, index, axis_size, Limits::min()); prevmax = maxval; - maxval = max_op(maxval, cub::ThreadReduce(vals, max_op)); +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + maxval = max_op(maxval, static_cast(vals[i])); + } + // Online normalizer calculation for softmax: // https://github.com/NVIDIA/online-softmax normalizer = normalizer * softmax_exp(prevmax - maxval); +#pragma unroll for (int i = 0; i < N_READS; i++) { - normalizer = normalizer + softmax_exp(vals[i] - maxval); + normalizer = + normalizer + softmax_exp(static_cast(vals[i]) - maxval); } } @@ -95,12 +95,11 @@ __global__ void softmax(const T* in, T* out, int axis_size) { // Write output. for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) { auto index = r * BLOCK_DIM + block.thread_rank(); - T vals[N_READS]; - cub::LoadDirectBlocked(index, in, vals, axis_size); + auto vals = load_vector(in, index, axis_size, T(0)); for (int i = 0; i < N_READS; i++) { vals[i] = softmax_exp(static_cast(vals[i]) - maxval) * normalizer; } - cub::StoreDirectBlocked(index, out, vals, axis_size); + store_vector(out, index, vals, axis_size); } } @@ -141,9 +140,9 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); encoder.set_output_array(out); dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { - constexpr int N_READS = 4; + using DataType = cuda_type_t; + constexpr int N_READS = 16 / sizeof(DataType); dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { - using DataType = cuda_type_t; auto kernel = cu::softmax; if (precise) { kernel = cu::softmax;