mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 01:50:16 +08:00
Distributed layers (#1270)
This commit is contained in:

committed by
GitHub

parent
69e4dd506b
commit
4eef8102c9
@@ -205,8 +205,10 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous)) {
|
||||
constexpr size_t extra_bytes = 16384;
|
||||
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
|
||||
(in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous))) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
copy(in, out, CopyType::General, stream());
|
||||
|
@@ -251,8 +251,10 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous)) {
|
||||
constexpr size_t extra_bytes = 16384;
|
||||
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
|
||||
(in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous))) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
copy_gpu(in, out, CopyType::General);
|
||||
|
@@ -993,6 +993,9 @@ array concatenate(
|
||||
throw std::invalid_argument(
|
||||
"[concatenate] No arrays provided for concatenation");
|
||||
}
|
||||
if (arrays.size() == 1) {
|
||||
return arrays[0];
|
||||
}
|
||||
|
||||
auto ax = normalize_axis_index(axis, arrays[0].ndim(), "[concatenate] ");
|
||||
|
||||
|
Reference in New Issue
Block a user