mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 08:41:13 +08:00
parent
02fedbf1da
commit
639e06e1f3
@ -606,8 +606,12 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
|||||||
|
|
||||||
auto slice = nb::cast<nb::slice>(idx);
|
auto slice = nb::cast<nb::slice>(idx);
|
||||||
int stride = get_slice_int(nb::getattr(slice, "step"), 1);
|
int stride = get_slice_int(nb::getattr(slice, "step"), 1);
|
||||||
num_strided_slices += (stride != 1);
|
if (stride != 1) {
|
||||||
num_simple_slices_post += (stride == 1);
|
num_strided_slices++;
|
||||||
|
num_simple_slices_post = 0;
|
||||||
|
} else {
|
||||||
|
num_simple_slices_post++;
|
||||||
|
}
|
||||||
|
|
||||||
} else if (nb::isinstance<array>(idx)) {
|
} else if (nb::isinstance<array>(idx)) {
|
||||||
have_array = true;
|
have_array = true;
|
||||||
|
@ -1210,6 +1210,23 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
np.array([1, 3]),
|
np.array([1, 3]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
check_slices(
|
||||||
|
np.zeros((3, 4, 5, 3)),
|
||||||
|
np.arange(2 * 4 * 3 * 3).reshape(2, 4, 3, 3),
|
||||||
|
np.array([2, 1]),
|
||||||
|
slice(None, None, None),
|
||||||
|
slice(None, None, 2),
|
||||||
|
slice(None, None, None),
|
||||||
|
)
|
||||||
|
|
||||||
|
check_slices(
|
||||||
|
np.zeros((3, 4, 5, 3)),
|
||||||
|
np.arange(2 * 4 * 3 * 3).reshape(2, 4, 3, 3),
|
||||||
|
np.array([2, 1]),
|
||||||
|
slice(None, None, None),
|
||||||
|
slice(None, None, 2),
|
||||||
|
)
|
||||||
|
|
||||||
def test_array_at(self):
|
def test_array_at(self):
|
||||||
a = mx.array(1)
|
a = mx.array(1)
|
||||||
a = a.at[None].add(1)
|
a = a.at[None].add(1)
|
||||||
|
Loading…
Reference in New Issue
Block a user