diff --git a/mlx/backend/accelerate/softmax.cpp b/mlx/backend/accelerate/softmax.cpp
index fcd8fbe50..9e7ddf632 100644
--- a/mlx/backend/accelerate/softmax.cpp
+++ b/mlx/backend/accelerate/softmax.cpp
@@ -274,7 +274,12 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
 
   // Make sure that the last dimension is contiguous
   auto check_input = [](array x) {
-    if (x.strides()[x.ndim() - 1] == 1) {
+    bool no_copy = x.strides()[x.ndim() - 1] == 1;
+    if (x.ndim() > 1) {
+      // Rows must also be packed back-to-back: the second-to-last stride
+      // must equal the last dimension's size (or be 0 for broadcasting).
+      auto stride = x.strides()[x.ndim() - 2];
+      no_copy &= (stride == 0 || stride == x.shape().back());
+    }
+    if (no_copy) {
       return x;
     } else {
       array x_copy(x.shape(), x.dtype(), nullptr, {});
diff --git a/mlx/backend/common/softmax.cpp b/mlx/backend/common/softmax.cpp
index 90874c72d..564fd1f22 100644
--- a/mlx/backend/common/softmax.cpp
+++ b/mlx/backend/common/softmax.cpp
@@ -53,7 +53,12 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
 
   // Make sure that the last dimension is contiguous
   auto check_input = [](array x) {
-    if (x.strides().back() == 1) {
+    bool no_copy = x.strides()[x.ndim() - 1] == 1;
+    if (x.ndim() > 1) {
+      // Rows must also be packed back-to-back: the second-to-last stride
+      // must equal the last dimension's size (or be 0 for broadcasting).
+      auto stride = x.strides()[x.ndim() - 2];
+      no_copy &= (stride == 0 || stride == x.shape().back());
+    }
+    if (no_copy) {
       return x;
     } else {
       array x_copy(x.shape(), x.dtype(), nullptr, {});
diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp
index 33ec8014c..7edc91b55 100644
--- a/mlx/backend/metal/softmax.cpp
+++ b/mlx/backend/metal/softmax.cpp
@@ -22,7 +22,12 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Make sure that the last dimension is contiguous
   std::vector<array> copies;
   auto check_input = [&copies, &s](const array& x) {
-    if (x.strides()[x.ndim() - 1] == 1) {
+    bool no_copy = x.strides()[x.ndim() - 1] == 1;
+    if (x.ndim() > 1) {
+      // `stride` (not `s`, which is the captured stream): the second-to-last
+      // stride must equal the last dimension's size (or be 0 for broadcasting).
+      auto stride = x.strides()[x.ndim() - 2];
+      no_copy &= (stride == 0 || stride == x.shape().back());
+    }
+    if (no_copy) {
       return x;
     } else {
       array x_copy(x.shape(), x.dtype(), nullptr, {});
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 9ad6d5a53..433890237 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1386,6 +1386,11 @@ class TestOps(mlx_tests.MLXTestCase):
         self.assertTrue((a[:-1] < 1e-9).all())
         self.assertEqual(a[-1], 1)
 
+        # Sliced inputs
+        y = mx.random.uniform(shape=(8, 4))
+        out = mx.softmax(y[:, 0:2], axis=-1)
+        self.assertAlmostEqual(out.sum().item(), 8.0)
+
     def test_concatenate(self):
         a_npy = np.random.randn(32, 32, 32)
         b_npy = np.random.randn(32, 32, 32)