mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Split multi output (#461)
* Multi-output split primitive * Add the multi-output split to the ArrayIterator * Add some grad tests for split
This commit is contained in:
committed by
GitHub
parent
4e290d282f
commit
d8fabaa12b
@@ -127,11 +127,7 @@ class array {
|
||||
using value_type = const array;
|
||||
using reference = value_type;
|
||||
|
||||
explicit ArrayIterator(const array& arr, int idx = 0) : arr(arr), idx(idx) {
|
||||
if (arr.ndim() == 0) {
|
||||
throw std::invalid_argument("Cannot iterate over 0-d array.");
|
||||
}
|
||||
}
|
||||
explicit ArrayIterator(const array& arr, int idx = 0);
|
||||
|
||||
reference operator*() const;
|
||||
|
||||
@@ -155,6 +151,7 @@ class array {
|
||||
private:
|
||||
const array& arr;
|
||||
int idx;
|
||||
std::vector<array> splits;
|
||||
};
|
||||
|
||||
ArrayIterator begin() const {
|
||||
|
||||
Reference in New Issue
Block a user