diff --git a/mlx/backend/metal/kernels/layer_norm.metal b/mlx/backend/metal/kernels/layer_norm.metal index 06b8be55f..ea77b53dc 100644 --- a/mlx/backend/metal/kernels/layer_norm.metal +++ b/mlx/backend/metal/kernels/layer_norm.metal @@ -31,6 +31,7 @@ inline void threadgroup_sum( for (int i = 0; i < N; i++) { x[i] = simd_sum(x[i]); } + threadgroup_barrier(mem_flags::mem_threadgroup); if (simd_lane_id == 0) { for (int i = 0; i < N; i++) { xs[N * simd_group_id + i] = x[i];