mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Fix split optimization for array iterator (#484)
This commit is contained in:

committed by
GitHub

parent
78e5f2d17d
commit
9c111f176d
@@ -169,21 +169,9 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
||||
if (arr.ndim() == 0) {
|
||||
throw std::invalid_argument("Cannot iterate over 0-d array.");
|
||||
}
|
||||
|
||||
// Iterate using split
|
||||
if (arr.shape(0) > 0 && arr.shape(0) <= 10) {
|
||||
splits = split(arr, arr.shape(0));
|
||||
for (auto& arr_i : splits) {
|
||||
arr_i = squeeze(arr_i, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
|
||||
if (idx >= 0 && idx < splits.size()) {
|
||||
return splits[idx];
|
||||
}
|
||||
|
||||
auto start = std::vector<int>(arr.ndim(), 0);
|
||||
auto end = arr.shape();
|
||||
auto shape = arr.shape();
|
||||
|
@@ -151,7 +151,6 @@ class array {
|
||||
private:
|
||||
const array& arr;
|
||||
int idx;
|
||||
std::vector<array> splits;
|
||||
};
|
||||
|
||||
ArrayIterator begin() const {
|
||||
|
Reference in New Issue
Block a user