From c41f7565ed373114e87493fe76c30efddb73b322 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 3 Apr 2025 08:32:59 -0700 Subject: [PATCH] fix softmax / logsumexp (#2042) --- mlx/backend/metal/kernels/logsumexp.h | 1 + mlx/backend/metal/kernels/softmax.h | 1 + 2 files changed, 2 insertions(+) diff --git a/mlx/backend/metal/kernels/logsumexp.h b/mlx/backend/metal/kernels/logsumexp.h index 374bbcd41..b6898e31e 100644 --- a/mlx/backend/metal/kernels/logsumexp.h +++ b/mlx/backend/metal/kernels/logsumexp.h @@ -33,6 +33,7 @@ template local_max[simd_lane_id] = Limits::min; local_normalizer[simd_lane_id] = 0; } + threadgroup_barrier(mem_flags::mem_threadgroup); // Get the max AccT maxval = Limits::finite_min; diff --git a/mlx/backend/metal/kernels/softmax.h b/mlx/backend/metal/kernels/softmax.h index 43e593d0e..b36b73bd8 100644 --- a/mlx/backend/metal/kernels/softmax.h +++ b/mlx/backend/metal/kernels/softmax.h @@ -40,6 +40,7 @@ template local_max[simd_lane_id] = Limits::min; local_normalizer[simd_lane_id] = 0; } + threadgroup_barrier(mem_flags::mem_threadgroup); // Get the max AccT maxval = Limits::finite_min;