Distributed layers (#1270)

This commit is contained in:
Angelos Katharopoulos
2025-03-21 13:52:17 -07:00
committed by GitHub
parent 69e4dd506b
commit 4eef8102c9
10 changed files with 895 additions and 80 deletions

View File

@@ -993,6 +993,9 @@ array concatenate(
throw std::invalid_argument(
"[concatenate] No arrays provided for concatenation");
}
if (arrays.size() == 1) {
return arrays[0];
}
auto ax = normalize_axis_index(axis, arrays[0].ndim(), "[concatenate] ");