From 7f3f8d8f8d081fc79456565b07a16b7a2f2da520 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 9 Feb 2024 17:02:13 -0800 Subject: [PATCH] Fix the softmax fix (#661) --- mlx/backend/accelerate/softmax.cpp | 2 +- mlx/backend/common/softmax.cpp | 2 +- mlx/backend/metal/softmax.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlx/backend/accelerate/softmax.cpp b/mlx/backend/accelerate/softmax.cpp index 9e7ddf632a..8b95e32d46 100644 --- a/mlx/backend/accelerate/softmax.cpp +++ b/mlx/backend/accelerate/softmax.cpp @@ -276,7 +276,7 @@ void Softmax::eval_cpu(const std::vector& inputs, array& out) { auto check_input = [](array x) { bool no_copy = x.strides()[x.ndim() - 1] == 1; if (x.ndim() > 1) { - auto s = x.strides()[x.ndim() - 1]; + auto s = x.strides()[x.ndim() - 2]; no_copy &= (s == 0 || s == x.shape().back()); } if (no_copy) { diff --git a/mlx/backend/common/softmax.cpp b/mlx/backend/common/softmax.cpp index 564fd1f227..87ce748c84 100644 --- a/mlx/backend/common/softmax.cpp +++ b/mlx/backend/common/softmax.cpp @@ -55,7 +55,7 @@ void Softmax::eval(const std::vector& inputs, array& out) { auto check_input = [](array x) { bool no_copy = x.strides()[x.ndim() - 1] == 1; if (x.ndim() > 1) { - auto s = x.strides()[x.ndim() - 1]; + auto s = x.strides()[x.ndim() - 2]; no_copy &= (s == 0 || s == x.shape().back()); } if (no_copy) { diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index 7edc91b55c..be25bc032f 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -24,7 +24,7 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { auto check_input = [&copies, &s](const array& x) { bool no_copy = x.strides()[x.ndim() - 1] == 1; if (x.ndim() > 1) { - auto s = x.strides()[x.ndim() - 1]; + auto s = x.strides()[x.ndim() - 2]; no_copy &= (s == 0 || s == x.shape().back()); } if (no_copy) {