This commit is contained in:
Awni Hannun 2024-02-09 16:50:45 -08:00 committed by GitHub
parent b670485185
commit b96be943dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 23 additions and 3 deletions

View File

@ -274,7 +274,12 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
if (x.strides()[x.ndim() - 1] == 1) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 1];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});

View File

@ -53,7 +53,12 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
if (x.strides().back() == 1) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 1];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});

View File

@ -22,7 +22,12 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) {
if (x.strides()[x.ndim() - 1] == 1) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 1];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});

View File

@ -1386,6 +1386,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue((a[:-1] < 1e-9).all())
self.assertEqual(a[-1], 1)
# Sliced inputs
y = mx.random.uniform(shape=(8, 4))
out = mx.softmax(y[:, 0:2], axis=-1)
self.assertAlmostEqual(out.sum().item(), 8.0)
def test_concatenate(self):
a_npy = np.random.randn(32, 32, 32)
b_npy = np.random.randn(32, 32, 32)