[Fix] expand axes for dimension with integer indices in mlx_slice_update (#1035)

* Not sure if this is correct

* Format

* Edit tests

* Add negative test

* Format

* add one more test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Jacket 2024-04-29 09:57:28 -05:00 committed by GitHub
parent c4a471c99d
commit 490c0c4fdc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 3 deletions

View File

@ -845,9 +845,7 @@ auto mlx_slice_update(
st = (st < 0) ? st + src.shape(ax) : st; st = (st < 0) ? st + src.shape(ax) : st;
starts[ax] = st; starts[ax] = st;
stops[ax] = st + 1; stops[ax] = st + 1;
if (src.ndim() - ax < up.ndim()) { upd_expand_dims.push_back(ax);
upd_expand_dims.push_back(ax - src.ndim());
}
ax++; ax++;
} }
} }

View File

@ -1194,6 +1194,10 @@ class TestArray(mlx_tests.MLXTestCase):
a = mx.zeros((2, 2)) a = mx.zeros((2, 2))
a[0, 0, 0] = 1 a[0, 0, 0] = 1
with self.assertRaises(ValueError):
a = mx.zeros((5, 4, 3))
a[:, 0] = mx.ones((5, 1, 3))
check_slices(np.zeros((2, 2, 2, 2)), 1, None, Ellipsis, None) check_slices(np.zeros((2, 2, 2, 2)), 1, None, Ellipsis, None)
check_slices( check_slices(
np.zeros((2, 2, 2, 2)), 1, np.array([0, 1]), Ellipsis, np.array([0, 1]) np.zeros((2, 2, 2, 2)), 1, np.array([0, 1]), Ellipsis, np.array([0, 1])
@ -1251,6 +1255,13 @@ class TestArray(mlx_tests.MLXTestCase):
slice(None, None, 2), slice(None, None, 2),
) )
check_slices(np.zeros((5, 4, 3)), np.ones((5, 3)), slice(None), 0)
check_slices(np.zeros((5, 4, 3)), np.ones((5, 1, 3)), slice(None), slice(0, 1))
check_slices(
np.ones((3, 4, 4, 4)), np.zeros((4, 4)), 0, slice(0, 4), 3, slice(0, 4)
)
def test_array_at(self): def test_array_at(self):
a = mx.array(1) a = mx.array(1)
a = a.at[None].add(1) a = a.at[None].add(1)