diff --git a/mlx/backend/metal/kernels/logsumexp.h b/mlx/backend/metal/kernels/logsumexp.h index 374bbcd41..b6898e31e 100644 --- a/mlx/backend/metal/kernels/logsumexp.h +++ b/mlx/backend/metal/kernels/logsumexp.h @@ -33,6 +33,7 @@ template local_max[simd_lane_id] = Limits::min; local_normalizer[simd_lane_id] = 0; } + threadgroup_barrier(mem_flags::mem_threadgroup); // Get the max AccT maxval = Limits::finite_min; diff --git a/mlx/backend/metal/kernels/softmax.h b/mlx/backend/metal/kernels/softmax.h index 43e593d0e..b36b73bd8 100644 --- a/mlx/backend/metal/kernels/softmax.h +++ b/mlx/backend/metal/kernels/softmax.h @@ -40,6 +40,7 @@ template local_max[simd_lane_id] = Limits::min; local_normalizer[simd_lane_id] = 0; } + threadgroup_barrier(mem_flags::mem_threadgroup); // Get the max AccT maxval = Limits::finite_min;