mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
Fix split optimization for array iterator (#484)
This commit is contained in:
parent
78e5f2d17d
commit
9c111f176d
@ -169,21 +169,9 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
||||
if (arr.ndim() == 0) {
|
||||
throw std::invalid_argument("Cannot iterate over 0-d array.");
|
||||
}
|
||||
|
||||
// Iterate using split
|
||||
if (arr.shape(0) > 0 && arr.shape(0) <= 10) {
|
||||
splits = split(arr, arr.shape(0));
|
||||
for (auto& arr_i : splits) {
|
||||
arr_i = squeeze(arr_i, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
|
||||
if (idx >= 0 && idx < splits.size()) {
|
||||
return splits[idx];
|
||||
}
|
||||
|
||||
auto start = std::vector<int>(arr.ndim(), 0);
|
||||
auto end = arr.shape();
|
||||
auto shape = arr.shape();
|
||||
|
@ -151,7 +151,6 @@ class array {
|
||||
private:
|
||||
const array& arr;
|
||||
int idx;
|
||||
std::vector<array> splits;
|
||||
};
|
||||
|
||||
ArrayIterator begin() const {
|
||||
|
@ -493,6 +493,32 @@ class ArrayAt {
|
||||
py::object indices_;
|
||||
};
|
||||
|
||||
class ArrayPythonIterator {
|
||||
public:
|
||||
ArrayPythonIterator(array x) : idx_(0), x_(std::move(x)) {
|
||||
if (x_.shape(0) > 0 && x_.shape(0) < 10) {
|
||||
splits_ = split(x_, x_.shape(0));
|
||||
}
|
||||
}
|
||||
|
||||
array next() {
|
||||
if (idx_ >= x_.shape(0)) {
|
||||
throw py::stop_iteration();
|
||||
}
|
||||
|
||||
if (idx_ >= 0 && idx_ < splits_.size()) {
|
||||
return squeeze(splits_[idx_++], 0);
|
||||
}
|
||||
|
||||
return *(x_.begin() + idx_++);
|
||||
}
|
||||
|
||||
private:
|
||||
int idx_;
|
||||
array x_;
|
||||
std::vector<array> splits_;
|
||||
};
|
||||
|
||||
void init_array(py::module_& m) {
|
||||
// Types
|
||||
py::class_<Dtype>(
|
||||
@ -539,6 +565,13 @@ void init_array(py::module_& m) {
|
||||
A helper object to apply updates at specific indices.
|
||||
)pbdoc");
|
||||
|
||||
auto array_iterator_class = py::class_<ArrayPythonIterator>(
|
||||
m,
|
||||
"_ArrayIterator",
|
||||
R"pbdoc(
|
||||
A helper object to iterate over the 1st dimension of an array.
|
||||
)pbdoc");
|
||||
|
||||
auto array_class = py::class_<array>(
|
||||
m,
|
||||
"array",
|
||||
@ -575,6 +608,16 @@ void init_array(py::module_& m) {
|
||||
.def("maximum", &ArrayAt::maximum, "value"_a)
|
||||
.def("minimum", &ArrayAt::minimum, "value"_a);
|
||||
|
||||
array_iterator_class
|
||||
.def(
|
||||
py::init([](const array& x) { return ArrayPythonIterator(x); }),
|
||||
"x"_a,
|
||||
R"pbdoc(
|
||||
__init__(self, x: array)
|
||||
)pbdoc")
|
||||
.def("__next__", &ArrayPythonIterator::next)
|
||||
.def("__iter__", [](const ArrayPythonIterator& it) { return it; });
|
||||
|
||||
array_class
|
||||
.def_buffer([](array& a) {
|
||||
// Eval if not already evaled
|
||||
@ -703,10 +746,7 @@ void init_array(py::module_& m) {
|
||||
}
|
||||
return a.shape(0);
|
||||
})
|
||||
.def(
|
||||
"__iter__",
|
||||
[](const array& a) { return py::make_iterator(a); },
|
||||
py::keep_alive<0, 1>())
|
||||
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
|
||||
.def(
|
||||
"__add__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
|
Loading…
Reference in New Issue
Block a user