Split multi output (#461)

* Multi-output split primitive
* Add the multi-output split to the ArrayIterator
* Add some grad tests for split
This commit is contained in:
Angelos Katharopoulos
2024-01-16 13:33:55 -08:00
committed by GitHub
parent 4e290d282f
commit d8fabaa12b
12 changed files with 202 additions and 5 deletions

View File

@@ -127,11 +127,7 @@ class array {
using value_type = const array;
using reference = value_type;
explicit ArrayIterator(const array& arr, int idx = 0) : arr(arr), idx(idx) {
if (arr.ndim() == 0) {
throw std::invalid_argument("Cannot iterate over 0-d array.");
}
}
explicit ArrayIterator(const array& arr, int idx = 0);
reference operator*() const;
@@ -155,6 +151,7 @@ class array {
private:
const array& arr;
int idx;
std::vector<array> splits;
};
ArrayIterator begin() const {