// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>

#include <cassert>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename T>
inline __device__ T softmax_exp(T x) {
  // Softmax doesn't need a high precision exponential because x is going to be
  // in (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
  return __expf(x);
}

template <typename T, typename AccT, int BLOCK_DIM, int N_READS = 4>
__global__ void logsumexp(const T* in, T* out, int axis_size) {
  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<WARP_SIZE>(block);

  in += grid.block_rank() * axis_size;

  cg::greater<AccT> max_op;
  cg::plus<AccT> plus_op;

  // Thread reduce.
  AccT prevmax;
  AccT maxval = Limits<AccT>::finite_min();
  AccT normalizer = 0;
  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
    auto index = r * BLOCK_DIM + block.thread_rank();
    auto vals = load_vector<N_READS>(in, index, axis_size, Limits<T>::min());
    prevmax = maxval;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      maxval = max_op(maxval, static_cast<AccT>(vals[i]));
    }
    // Online normalizer calculation for softmax:
    // https://github.com/NVIDIA/online-softmax
    normalizer = normalizer * softmax_exp(prevmax - maxval);
    for (int i = 0; i < N_READS; i++) {
      normalizer =
          normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
    }
  }

  // First warp reduce.
  prevmax = maxval;
  maxval = cg::reduce(warp, maxval, max_op);
  normalizer = normalizer * softmax_exp(prevmax - maxval);
  normalizer = cg::reduce(warp, normalizer, plus_op);

  __shared__ AccT local_max[WARP_SIZE];
  __shared__ AccT local_normalizer[WARP_SIZE];

  // Write to shared memory and do second warp reduce.
  prevmax = maxval;
  if (warp.thread_rank() == 0) {
    local_max[warp.meta_group_rank()] = maxval;
  }
  block.sync();
  maxval = warp.thread_rank() < warp.meta_group_size()
      ? local_max[warp.thread_rank()]
      : Limits<AccT>::finite_min();
  maxval = cg::reduce(warp, maxval, max_op);
  normalizer = normalizer * softmax_exp(prevmax - maxval);
  if (warp.thread_rank() == 0) {
    local_normalizer[warp.meta_group_rank()] = normalizer;
  }
  block.sync();
  normalizer = warp.thread_rank() < warp.meta_group_size()
      ? local_normalizer[warp.thread_rank()]
      : AccT{};
  normalizer = cg::reduce(warp, normalizer, plus_op);

  // Write output.
  if (block.thread_rank() == 0) {
    out[grid.block_rank()] = isinf(maxval) ? maxval : log(normalizer) + maxval;
  }
}

} // namespace cu

void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("LogSumExp::eval_gpu");
  assert(inputs.size() == 1);
  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  // Make sure that the last dimension is contiguous.
  auto ensure_contiguous = [&s, &encoder](const array& x) {
    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
      return x;
    } else {
      array x_copy = contiguous_copy_gpu(x, s);
      encoder.add_temporary(x_copy);
      return x_copy;
    }
  };

  auto in = ensure_contiguous(inputs[0]);
  if (in.flags().row_contiguous) {
    out.set_data(allocator::malloc(out.nbytes()));
  } else {
    auto n = in.shape(-1);
    auto flags = in.flags();
    auto strides = in.strides();
    for (auto& s : strides) {
      s /= n;
    }
    bool col_contig = strides[0] == 1;
    for (int i = 1; col_contig && i < strides.size(); ++i) {
      col_contig &=
          (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
    }
    flags.col_contiguous = col_contig;
    out.set_data(
        allocator::malloc(in.nbytes() / n),
        in.data_size() / n,
        std::move(strides),
        flags);
  }

  int axis_size = in.shape().back();
  int n_rows = in.data_size() / axis_size;

  encoder.set_input_array(in);
  encoder.set_output_array(out);
  dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) {
    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
    constexpr int N_READS = 16 / sizeof(DataType);
    dispatch_block_dim(
        cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
          auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>;
          encoder.add_kernel_node(
              kernel,
              n_rows,
              block_dim(),
              in.data<DataType>(),
              out.data<DataType>(),
              axis_size);
        });
  });
}

} // namespace mlx::core
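
// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the library; excluded from compilation).
// The kernel above folds the running maximum and the normalizer into a single
// pass using the online-softmax recurrence
// (https://github.com/NVIDIA/online-softmax). The host-side C++ below shows
// the same recurrence on a plain float vector; the name `logsumexp_reference`
// is made up for this sketch and is not an MLX API.
// ---------------------------------------------------------------------------
#if 0
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

float logsumexp_reference(const std::vector<float>& x) {
  float maxval = -std::numeric_limits<float>::infinity();
  float normalizer = 0.0f;
  for (float v : x) {
    float prevmax = maxval;
    maxval = std::max(maxval, v);
    // Rescale the running sum whenever the running max grows, then add the
    // current term; exp(prevmax - maxval) == 1 when the max is unchanged.
    normalizer =
        normalizer * std::exp(prevmax - maxval) + std::exp(v - maxval);
  }
  // Mirror the kernel's output step: log(sum(exp(x - max))) + max, with the
  // infinite-max cases short-circuited through the max itself.
  return std::isinf(maxval) ? maxval : std::log(normalizer) + maxval;
}
#endif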