diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu
index 9e61b9ff0..a85b51583 100644
--- a/mlx/backend/cuda/logsumexp.cu
+++ b/mlx/backend/cuda/logsumexp.cu
@@ -54,7 +54,7 @@ __global__ void logsumexp(const T* in, T* out, int axis_size) {
     // https://github.com/NVIDIA/online-softmax
     normalizer = normalizer * softmax_exp(prevmax - maxval);
     for (int i = 0; i < N_READS; i++) {
-      normalizer = normalizer + softmax_exp(vals[i] - maxval);
+      normalizer = normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
     }
   }
diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu
index 362c88431..2ff4464b0 100644
--- a/mlx/backend/cuda/softmax.cu
+++ b/mlx/backend/cuda/softmax.cu
@@ -57,7 +57,8 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
     normalizer = normalizer * softmax_exp(prevmax - maxval);
 #pragma unroll
     for (int i = 0; i < N_READS; i++) {
-      normalizer = normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
+      normalizer =
+          normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
     }
   }