mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 10:41:14 +08:00
Fix split optimization for array iterator (#484)
This commit is contained in:
parent
78e5f2d17d
commit
9c111f176d
@ -169,21 +169,9 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
|||||||
if (arr.ndim() == 0) {
|
if (arr.ndim() == 0) {
|
||||||
throw std::invalid_argument("Cannot iterate over 0-d array.");
|
throw std::invalid_argument("Cannot iterate over 0-d array.");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Iterate using split
|
|
||||||
if (arr.shape(0) > 0 && arr.shape(0) <= 10) {
|
|
||||||
splits = split(arr, arr.shape(0));
|
|
||||||
for (auto& arr_i : splits) {
|
|
||||||
arr_i = squeeze(arr_i, 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
|
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
|
||||||
if (idx >= 0 && idx < splits.size()) {
|
|
||||||
return splits[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
auto start = std::vector<int>(arr.ndim(), 0);
|
auto start = std::vector<int>(arr.ndim(), 0);
|
||||||
auto end = arr.shape();
|
auto end = arr.shape();
|
||||||
auto shape = arr.shape();
|
auto shape = arr.shape();
|
||||||
|
@ -151,7 +151,6 @@ class array {
|
|||||||
private:
|
private:
|
||||||
const array& arr;
|
const array& arr;
|
||||||
int idx;
|
int idx;
|
||||||
std::vector<array> splits;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
ArrayIterator begin() const {
|
ArrayIterator begin() const {
|
||||||
|
@ -493,6 +493,32 @@ class ArrayAt {
|
|||||||
py::object indices_;
|
py::object indices_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ArrayPythonIterator {
|
||||||
|
public:
|
||||||
|
ArrayPythonIterator(array x) : idx_(0), x_(std::move(x)) {
|
||||||
|
if (x_.shape(0) > 0 && x_.shape(0) < 10) {
|
||||||
|
splits_ = split(x_, x_.shape(0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
array next() {
|
||||||
|
if (idx_ >= x_.shape(0)) {
|
||||||
|
throw py::stop_iteration();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (idx_ >= 0 && idx_ < splits_.size()) {
|
||||||
|
return squeeze(splits_[idx_++], 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
return *(x_.begin() + idx_++);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int idx_;
|
||||||
|
array x_;
|
||||||
|
std::vector<array> splits_;
|
||||||
|
};
|
||||||
|
|
||||||
void init_array(py::module_& m) {
|
void init_array(py::module_& m) {
|
||||||
// Types
|
// Types
|
||||||
py::class_<Dtype>(
|
py::class_<Dtype>(
|
||||||
@ -539,6 +565,13 @@ void init_array(py::module_& m) {
|
|||||||
A helper object to apply updates at specific indices.
|
A helper object to apply updates at specific indices.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
|
||||||
|
auto array_iterator_class = py::class_<ArrayPythonIterator>(
|
||||||
|
m,
|
||||||
|
"_ArrayIterator",
|
||||||
|
R"pbdoc(
|
||||||
|
A helper object to iterate over the 1st dimension of an array.
|
||||||
|
)pbdoc");
|
||||||
|
|
||||||
auto array_class = py::class_<array>(
|
auto array_class = py::class_<array>(
|
||||||
m,
|
m,
|
||||||
"array",
|
"array",
|
||||||
@ -575,6 +608,16 @@ void init_array(py::module_& m) {
|
|||||||
.def("maximum", &ArrayAt::maximum, "value"_a)
|
.def("maximum", &ArrayAt::maximum, "value"_a)
|
||||||
.def("minimum", &ArrayAt::minimum, "value"_a);
|
.def("minimum", &ArrayAt::minimum, "value"_a);
|
||||||
|
|
||||||
|
array_iterator_class
|
||||||
|
.def(
|
||||||
|
py::init([](const array& x) { return ArrayPythonIterator(x); }),
|
||||||
|
"x"_a,
|
||||||
|
R"pbdoc(
|
||||||
|
__init__(self, x: array)
|
||||||
|
)pbdoc")
|
||||||
|
.def("__next__", &ArrayPythonIterator::next)
|
||||||
|
.def("__iter__", [](const ArrayPythonIterator& it) { return it; });
|
||||||
|
|
||||||
array_class
|
array_class
|
||||||
.def_buffer([](array& a) {
|
.def_buffer([](array& a) {
|
||||||
// Eval if not already evaled
|
// Eval if not already evaled
|
||||||
@ -703,10 +746,7 @@ void init_array(py::module_& m) {
|
|||||||
}
|
}
|
||||||
return a.shape(0);
|
return a.shape(0);
|
||||||
})
|
})
|
||||||
.def(
|
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
|
||||||
"__iter__",
|
|
||||||
[](const array& a) { return py::make_iterator(a); },
|
|
||||||
py::keep_alive<0, 1>())
|
|
||||||
.def(
|
.def(
|
||||||
"__add__",
|
"__add__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
|
Loading…
Reference in New Issue
Block a user