mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
[Fix] expand axes for dimension with integer indices in mlx_slice_update (#1035)
* Not sure if this is correct * Format * Edit tests * Add negative test * Format * add one more test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
c4a471c99d
commit
490c0c4fdc
@ -845,9 +845,7 @@ auto mlx_slice_update(
|
||||
st = (st < 0) ? st + src.shape(ax) : st;
|
||||
starts[ax] = st;
|
||||
stops[ax] = st + 1;
|
||||
if (src.ndim() - ax < up.ndim()) {
|
||||
upd_expand_dims.push_back(ax - src.ndim());
|
||||
}
|
||||
upd_expand_dims.push_back(ax);
|
||||
ax++;
|
||||
}
|
||||
}
|
||||
|
@ -1194,6 +1194,10 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
a = mx.zeros((2, 2))
|
||||
a[0, 0, 0] = 1
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
a = mx.zeros((5, 4, 3))
|
||||
a[:, 0] = mx.ones((5, 1, 3))
|
||||
|
||||
check_slices(np.zeros((2, 2, 2, 2)), 1, None, Ellipsis, None)
|
||||
check_slices(
|
||||
np.zeros((2, 2, 2, 2)), 1, np.array([0, 1]), Ellipsis, np.array([0, 1])
|
||||
@ -1251,6 +1255,13 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
slice(None, None, 2),
|
||||
)
|
||||
|
||||
check_slices(np.zeros((5, 4, 3)), np.ones((5, 3)), slice(None), 0)
|
||||
|
||||
check_slices(np.zeros((5, 4, 3)), np.ones((5, 1, 3)), slice(None), slice(0, 1))
|
||||
check_slices(
|
||||
np.ones((3, 4, 4, 4)), np.zeros((4, 4)), 0, slice(0, 4), 3, slice(0, 4)
|
||||
)
|
||||
|
||||
def test_array_at(self):
|
||||
a = mx.array(1)
|
||||
a = a.at[None].add(1)
|
||||
|
Loading…
Reference in New Issue
Block a user