Distributed layers (#1270)

This commit is contained in:
Angelos Katharopoulos
2025-03-21 13:52:17 -07:00
committed by GitHub
parent 69e4dd506b
commit 4eef8102c9
10 changed files with 895 additions and 80 deletions

View File

@@ -205,8 +205,10 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General, stream());

View File

@@ -251,8 +251,10 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy_gpu(in, out, CopyType::General);

View File

@@ -993,6 +993,9 @@ array concatenate(
throw std::invalid_argument(
"[concatenate] No arrays provided for concatenation");
}
if (arrays.size() == 1) {
return arrays[0];
}
auto ax = normalize_axis_index(axis, arrays[0].ndim(), "[concatenate] ");