Fix split optimization for array iterator (#484)

This commit is contained in:
Angelos Katharopoulos 2024-01-18 05:50:25 -08:00 committed by GitHub
parent 78e5f2d17d
commit 9c111f176d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 44 additions and 17 deletions

View File

@ -169,21 +169,9 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
if (arr.ndim() == 0) {
throw std::invalid_argument("Cannot iterate over 0-d array.");
}
// Iterate using split
if (arr.shape(0) > 0 && arr.shape(0) <= 10) {
splits = split(arr, arr.shape(0));
for (auto& arr_i : splits) {
arr_i = squeeze(arr_i, 0);
}
}
}
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
if (idx >= 0 && idx < splits.size()) {
return splits[idx];
}
auto start = std::vector<int>(arr.ndim(), 0);
auto end = arr.shape();
auto shape = arr.shape();

View File

@ -151,7 +151,6 @@ class array {
private:
const array& arr;
int idx;
std::vector<array> splits;
};
ArrayIterator begin() const {

View File

@ -493,6 +493,32 @@ class ArrayAt {
py::object indices_;
};
class ArrayPythonIterator {
public:
ArrayPythonIterator(array x) : idx_(0), x_(std::move(x)) {
if (x_.shape(0) > 0 && x_.shape(0) < 10) {
splits_ = split(x_, x_.shape(0));
}
}
array next() {
if (idx_ >= x_.shape(0)) {
throw py::stop_iteration();
}
if (idx_ >= 0 && idx_ < splits_.size()) {
return squeeze(splits_[idx_++], 0);
}
return *(x_.begin() + idx_++);
}
private:
int idx_;
array x_;
std::vector<array> splits_;
};
void init_array(py::module_& m) {
// Types
py::class_<Dtype>(
@ -539,6 +565,13 @@ void init_array(py::module_& m) {
A helper object to apply updates at specific indices.
)pbdoc");
auto array_iterator_class = py::class_<ArrayPythonIterator>(
m,
"_ArrayIterator",
R"pbdoc(
A helper object to iterate over the 1st dimension of an array.
)pbdoc");
auto array_class = py::class_<array>(
m,
"array",
@ -575,6 +608,16 @@ void init_array(py::module_& m) {
.def("maximum", &ArrayAt::maximum, "value"_a)
.def("minimum", &ArrayAt::minimum, "value"_a);
array_iterator_class
.def(
py::init([](const array& x) { return ArrayPythonIterator(x); }),
"x"_a,
R"pbdoc(
__init__(self, x: array)
)pbdoc")
.def("__next__", &ArrayPythonIterator::next)
.def("__iter__", [](const ArrayPythonIterator& it) { return it; });
array_class
.def_buffer([](array& a) {
// Eval if not already evaled
@ -703,10 +746,7 @@ void init_array(py::module_& m) {
}
return a.shape(0);
})
.def(
"__iter__",
[](const array& a) { return py::make_iterator(a); },
py::keep_alive<0, 1>())
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
.def(
"__add__",
[](const array& a, const ScalarOrArray v) {