diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu index a85b51583..2afcc7e70 100644 --- a/mlx/backend/cuda/logsumexp.cu +++ b/mlx/backend/cuda/logsumexp.cu @@ -54,7 +54,8 @@ __global__ void logsumexp(const T* in, T* out, int axis_size) { // https://github.com/NVIDIA/online-softmax normalizer = normalizer * softmax_exp(prevmax - maxval); for (int i = 0; i < N_READS; i++) { - normalizer = normalizer + softmax_exp(static_cast(vals[i]) - maxval); + normalizer = + normalizer + softmax_exp(static_cast(vals[i]) - maxval); } }