Fix split optimization for array iterator (#484)

This commit is contained in:
Angelos Katharopoulos
2024-01-18 05:50:25 -08:00
committed by GitHub
parent 78e5f2d17d
commit 9c111f176d
3 changed files with 44 additions and 17 deletions

View File

@@ -169,21 +169,9 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
if (arr.ndim() == 0) {
throw std::invalid_argument("Cannot iterate over 0-d array.");
}
// Iterate using split
if (arr.shape(0) > 0 && arr.shape(0) <= 10) {
splits = split(arr, arr.shape(0));
for (auto& arr_i : splits) {
arr_i = squeeze(arr_i, 0);
}
}
}
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
if (idx >= 0 && idx < splits.size()) {
return splits[idx];
}
auto start = std::vector<int>(arr.ndim(), 0);
auto end = arr.shape();
auto shape = arr.shape();

View File

@@ -151,7 +151,6 @@ class array {
private:
const array& arr;
int idx;
std::vector<array> splits;
};
ArrayIterator begin() const {