Split multi output (#461)

* Multi-output split primitive
* Add the multi-output split to the ArrayIterator
* Add some grad tests for split
This commit is contained in:
Angelos Katharopoulos
2024-01-16 13:33:55 -08:00
committed by GitHub
parent 4e290d282f
commit d8fabaa12b
12 changed files with 202 additions and 5 deletions

View File

@@ -727,6 +727,12 @@ void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sinh");
}
void Split::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "square");
}