mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
@@ -1058,6 +1058,29 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
a[2:-2, 2:-2] = 4
|
||||
self.assertEqual(a[2, 2].item(), 4)
|
||||
|
||||
# Check slice array slice
|
||||
check_slices(
|
||||
np.zeros((5, 4, 4)),
|
||||
np.arange(4 * 2 * 3).reshape(4, 2, 3),
|
||||
slice(0, 4),
|
||||
np.array([1, 3]),
|
||||
slice(None, -1),
|
||||
)
|
||||
check_slices(
|
||||
np.zeros((5, 4, 4)),
|
||||
np.arange(4 * 2 * 2).reshape(4, 2, 2),
|
||||
slice(0, 4),
|
||||
np.array([1, 3]),
|
||||
slice(0, 4, 2),
|
||||
)
|
||||
|
||||
check_slices(
|
||||
np.zeros((1, 10, 4)),
|
||||
np.arange(2 * 4).reshape(1, 2, 4),
|
||||
slice(None, None, None),
|
||||
np.array([1, 3]),
|
||||
)
|
||||
|
||||
def test_array_at(self):
|
||||
a = mx.array(1)
|
||||
a = a.at[None].add(1)
|
||||
|
Reference in New Issue
Block a user