From f0ae00da122b677021afd97004b08508f94008f3 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Fri, 22 Mar 2024 22:29:16 +0900
Subject: [PATCH] Reduce implicit copies in make_array (#874)

1. Move shapes into outputs instead of copying them.
2. Pass primitive by const ref as it is always copied into outputs, which
   removes a copy when calling make_array.
---
 mlx/array.cpp      | 11 ++++++-----
 mlx/array.h        |  4 ++--
 mlx/compile.cpp    |  6 +++---
 mlx/ops.cpp        |  9 ++++++---
 mlx/transforms.cpp |  2 +-
 5 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/mlx/array.cpp b/mlx/array.cpp
index c2982d05c..e14f7bec7 100644
--- a/mlx/array.cpp
+++ b/mlx/array.cpp
@@ -48,15 +48,16 @@ array::array(
           std::move(inputs))) {}
 
 std::vector<array> array::make_arrays(
-    const std::vector<std::vector<int>>& shapes,
+    std::vector<std::vector<int>> shapes,
     const std::vector<Dtype>& dtypes,
-    std::shared_ptr<Primitive> primitive,
+    const std::shared_ptr<Primitive>& primitive,
     const std::vector<array>& inputs) {
   std::vector<array> outputs;
-  for (int i = 0; i < shapes.size(); ++i) {
-    outputs.push_back(array(shapes[i], dtypes[i], primitive, inputs));
+  for (size_t i = 0; i < shapes.size(); ++i) {
+    outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
   }
-  for (int i = 0; i < outputs.size(); ++i) {
+  // For each node in |outputs|, its siblings are the other nodes.
+  for (size_t i = 0; i < outputs.size(); ++i) {
     auto siblings = outputs;
     siblings.erase(siblings.begin() + i);
     outputs[i].set_siblings(std::move(siblings), i);
diff --git a/mlx/array.h b/mlx/array.h
index c5a76d968..686dc9b33 100644
--- a/mlx/array.h
+++ b/mlx/array.h
@@ -180,9 +180,9 @@ class array {
       std::vector<array> inputs);
 
   static std::vector<array> make_arrays(
-      const std::vector<std::vector<int>>& shapes,
+      std::vector<std::vector<int>> shapes,
       const std::vector<Dtype>& dtypes,
-      std::shared_ptr<Primitive> primitive,
+      const std::shared_ptr<Primitive>& primitive,
       const std::vector<array>& inputs);
 
   /** A unique identifier for an array. */
diff --git a/mlx/compile.cpp b/mlx/compile.cpp
index ed776361a..13b12fee3 100644
--- a/mlx/compile.cpp
+++ b/mlx/compile.cpp
@@ -645,7 +645,7 @@ void compile_fuse(
       }
     }
     auto compiled_outputs = array::make_arrays(
-        shapes,
+        std::move(shapes),
         types,
         std::make_shared<Compiled>(
             old_outputs.back().primitive().stream(),
@@ -738,8 +738,8 @@ std::vector<array> compile_replace(
         shapes.push_back(o.shape());
       }
     }
-    auto real_out =
-        array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
+    auto real_out = array::make_arrays(
+        std::move(shapes), types, a.primitive_ptr(), real_inputs);
     for (int i = 0; i < trace_out.size(); ++i) {
       trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])});
     }
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 175fe6f89..2050fa374 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -614,7 +614,7 @@ std::vector<array> split(
   shapes.back()[ax] = a.shape(ax) - indices.back();
 
   return array::make_arrays(
-      shapes,
+      std::move(shapes),
       dtypes,
       std::make_shared<Split>(to_stream(s), indices, ax),
       {a});
@@ -1949,7 +1949,7 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   return array::make_arrays(
       {inputs[0].shape(), inputs[0].shape()},
       {inputs[0].dtype(), inputs[0].dtype()},
-      std::make_unique<DivMod>(to_stream(s)),
+      std::make_shared<DivMod>(to_stream(s)),
       inputs);
 }
 
@@ -3531,7 +3531,10 @@ std::vector<array> depends(
   }
 
   return array::make_arrays(
-      shapes, dtypes, std::make_shared<Depends>(to_stream(s)), all_inputs);
+      std::move(shapes),
+      dtypes,
+      std::make_shared<Depends>(to_stream(s)),
+      all_inputs);
 }
 
 array atleast_1d(const array& a, StreamOrDevice s /* = {} */) {
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index 03e05de1d..107181f78 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -782,7 +782,7 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
     }
 
     return array::make_arrays(
-        shapes,
+        std::move(shapes),
         dtypes,
         std::make_shared<CustomVJP>(to_stream(s), fun_vjp),
         inputs);