From ea451af9a05872242b98ba716d025aedd7daf88f Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 11 Jun 2025 09:58:15 -0700 Subject: [PATCH] Update no copy condition in normalization to account for axis size 1 --- mlx/backend/metal/normalization.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index d570bf3c0..8674eff72 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -26,7 +26,7 @@ void RMSNorm::eval_gpu( bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; - no_copy &= (s == 0 || s == x.shape().back()); + no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1); } if (no_copy) { if (x.is_donatable()) { @@ -227,7 +227,7 @@ void LayerNorm::eval_gpu( bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; - no_copy &= (s == 0 || s == x.shape().back()); + no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1); } if (no_copy) { if (x.is_donatable()) {