mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 23:51:14 +08:00
parent
02fedbf1da
commit
639e06e1f3
@ -606,8 +606,12 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
|
||||
auto slice = nb::cast<nb::slice>(idx);
|
||||
int stride = get_slice_int(nb::getattr(slice, "step"), 1);
|
||||
num_strided_slices += (stride != 1);
|
||||
num_simple_slices_post += (stride == 1);
|
||||
if (stride != 1) {
|
||||
num_strided_slices++;
|
||||
num_simple_slices_post = 0;
|
||||
} else {
|
||||
num_simple_slices_post++;
|
||||
}
|
||||
|
||||
} else if (nb::isinstance<array>(idx)) {
|
||||
have_array = true;
|
||||
|
@ -1210,6 +1210,23 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
np.array([1, 3]),
|
||||
)
|
||||
|
||||
check_slices(
|
||||
np.zeros((3, 4, 5, 3)),
|
||||
np.arange(2 * 4 * 3 * 3).reshape(2, 4, 3, 3),
|
||||
np.array([2, 1]),
|
||||
slice(None, None, None),
|
||||
slice(None, None, 2),
|
||||
slice(None, None, None),
|
||||
)
|
||||
|
||||
check_slices(
|
||||
np.zeros((3, 4, 5, 3)),
|
||||
np.arange(2 * 4 * 3 * 3).reshape(2, 4, 3, 3),
|
||||
np.array([2, 1]),
|
||||
slice(None, None, None),
|
||||
slice(None, None, 2),
|
||||
)
|
||||
|
||||
def test_array_at(self):
|
||||
a = mx.array(1)
|
||||
a = a.at[None].add(1)
|
||||
|
Loading…
Reference in New Issue
Block a user