diff --git a/mlx/array.cpp b/mlx/array.cpp index 207496286..ad5d4f122 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -169,21 +169,9 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx) if (arr.ndim() == 0) { throw std::invalid_argument("Cannot iterate over 0-d array."); } - - // Iterate using split - if (arr.shape(0) > 0 && arr.shape(0) <= 10) { - splits = split(arr, arr.shape(0)); - for (auto& arr_i : splits) { - arr_i = squeeze(arr_i, 0); - } - } } array::ArrayIterator::reference array::ArrayIterator::operator*() const { - if (idx >= 0 && idx < splits.size()) { - return splits[idx]; - } - auto start = std::vector(arr.ndim(), 0); auto end = arr.shape(); auto shape = arr.shape(); diff --git a/mlx/array.h b/mlx/array.h index 52c968c0e..0d530162c 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -151,7 +151,6 @@ class array { private: const array& arr; int idx; - std::vector splits; }; ArrayIterator begin() const { diff --git a/python/src/array.cpp b/python/src/array.cpp index 8f6e1ac4f..407142fda 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -493,6 +493,32 @@ class ArrayAt { py::object indices_; }; +class ArrayPythonIterator { + public: + ArrayPythonIterator(array x) : idx_(0), x_(std::move(x)) { + if (x_.shape(0) > 0 && x_.shape(0) < 10) { + splits_ = split(x_, x_.shape(0)); + } + } + + array next() { + if (idx_ >= x_.shape(0)) { + throw py::stop_iteration(); + } + + if (idx_ >= 0 && idx_ < splits_.size()) { + return squeeze(splits_[idx_++], 0); + } + + return *(x_.begin() + idx_++); + } + + private: + int idx_; + array x_; + std::vector splits_; +}; + void init_array(py::module_& m) { // Types py::class_( @@ -539,6 +565,13 @@ void init_array(py::module_& m) { A helper object to apply updates at specific indices. )pbdoc"); + auto array_iterator_class = py::class_( + m, + "_ArrayIterator", + R"pbdoc( + A helper object to iterate over the 1st dimension of an array. + )pbdoc"); + auto array_class = py::class_( m, "array", @@ -575,6 +608,16 @@ void init_array(py::module_& m) { .def("maximum", &ArrayAt::maximum, "value"_a) .def("minimum", &ArrayAt::minimum, "value"_a); + array_iterator_class + .def( + py::init([](const array& x) { return ArrayPythonIterator(x); }), + "x"_a, + R"pbdoc( + __init__(self, x: array) + )pbdoc") + .def("__next__", &ArrayPythonIterator::next) + .def("__iter__", [](const ArrayPythonIterator& it) { return it; }); + array_class .def_buffer([](array& a) { // Eval if not already evaled @@ -703,10 +746,7 @@ void init_array(py::module_& m) { } return a.shape(0); }) - .def( - "__iter__", - [](const array& a) { return py::make_iterator(a); }, - py::keep_alive<0, 1>()) + .def("__iter__", [](const array& a) { return ArrayPythonIterator(a); }) .def( "__add__", [](const array& a, const ScalarOrArray v) {