mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 17:39:05 +08:00
Fix split optimization for array iterator (#484)
This commit is contained in:
committed by
GitHub
parent
78e5f2d17d
commit
9c111f176d
@@ -169,21 +169,9 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
||||
if (arr.ndim() == 0) {
|
||||
throw std::invalid_argument("Cannot iterate over 0-d array.");
|
||||
}
|
||||
|
||||
// Iterate using split
|
||||
if (arr.shape(0) > 0 && arr.shape(0) <= 10) {
|
||||
splits = split(arr, arr.shape(0));
|
||||
for (auto& arr_i : splits) {
|
||||
arr_i = squeeze(arr_i, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
|
||||
if (idx >= 0 && idx < splits.size()) {
|
||||
return splits[idx];
|
||||
}
|
||||
|
||||
auto start = std::vector<int>(arr.ndim(), 0);
|
||||
auto end = arr.shape();
|
||||
auto shape = arr.shape();
|
||||
|
||||
Reference in New Issue
Block a user