From bab5386306d1928d719682104fd96637c891e850 Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 28 Mar 2024 14:35:55 +0900 Subject: [PATCH] Make ops aware of rvalues: astype/as_strided/copy/full (#895) When compositing transforms, lots of temporary arrays will be created and passed to the next primitive, and by making ops accept args by value we can avoid lots of copies of temporary arrays. --- mlx/ops.cpp | 58 ++++++++++++++++++++++++++++++------------------ mlx/ops.h | 27 +++++++++------------- mlx/primitives.h | 8 +++---- 3 files changed, 50 insertions(+), 43 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 2dfe33a05..c1e9c8e0f 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -158,50 +158,64 @@ array linspace( to_stream(s)); } -array astype(const array& a, Dtype dtype, StreamOrDevice s /* = {} */) { +array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) { if (dtype == a.dtype()) { - return a; + return std::move(a); } + auto copied_shape = a.shape(); // |a| will be moved return array( - a.shape(), dtype, std::make_shared(to_stream(s), dtype), {a}); + std::move(copied_shape), + dtype, + std::make_shared(to_stream(s), dtype), + {std::move(a)}); } array as_strided( - const array& a, + array a, std::vector shape, std::vector strides, size_t offset, StreamOrDevice s /* = {} */) { - // Force the input array to be contiguous - auto x = reshape(a, {-1}, s); + auto copied_shape = shape; // |shape| will be moved + auto dtype = a.dtype(); // |a| will be moved return array( - shape, - a.dtype(), - std::make_shared(to_stream(s), shape, strides, offset), - {x}); + std::move(copied_shape), + dtype, + std::make_shared( + to_stream(s), std::move(shape), std::move(strides), offset), + // Force the input array to be contiguous. 
+ {reshape(std::move(a), {-1}, s)}); } -array copy(const array& a, StreamOrDevice s /* = {} */) { - return array(a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); +array copy(array a, StreamOrDevice s /* = {} */) { + auto copied_shape = a.shape(); // |a| will be moved + auto dtype = a.dtype(); + return array( + std::move(copied_shape), + dtype, + std::make_shared(to_stream(s)), + {std::move(a)}); } array full( - const std::vector& shape, - const array& vals, + std::vector shape, + array vals, Dtype dtype, StreamOrDevice s /* = {} */) { - if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) { + if (std::any_of(shape.begin(), shape.end(), [](int i) { return i < 0; })) { throw std::invalid_argument("[full] Negative dimensions not allowed."); } - auto in = broadcast_to(astype(vals, dtype, s), shape, s); - return array(shape, dtype, std::make_shared(to_stream(s)), {in}); + auto copied_shape = shape; // |shape| will be moved + return array( + std::move(copied_shape), + dtype, + std::make_shared(to_stream(s)), + {broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)}); } -array full( - const std::vector& shape, - const array& vals, - StreamOrDevice s /* = {} */) { - return full(shape, vals, vals.dtype(), to_stream(s)); +array full(std::vector shape, array vals, StreamOrDevice s /* = {} */) { + auto dtype = vals.dtype(); // |vals| will be moved + return full(std::move(shape), std::move(vals), dtype, to_stream(s)); } array zeros( diff --git a/mlx/ops.h b/mlx/ops.h index ff9963ed7..69059ed75 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -41,40 +41,33 @@ array linspace( StreamOrDevice s = {}); /** Convert an array to the given data type. */ -array astype(const array& a, Dtype dtype, StreamOrDevice s = {}); +array astype(array a, Dtype dtype, StreamOrDevice s = {}); /** Create a view of an array with the given shape and strides. 
*/ array as_strided( - const array& a, + array a, std::vector shape, std::vector strides, size_t offset, StreamOrDevice s = {}); /** Copy another array. */ -array copy(const array& a, StreamOrDevice s = {}); +array copy(array a, StreamOrDevice s = {}); /** Fill an array of the given shape with the given value(s). */ array full( - const std::vector& shape, - const array& vals, + std::vector shape, + array vals, Dtype dtype, StreamOrDevice s = {}); -array full( - const std::vector& shape, - const array& vals, - StreamOrDevice s = {}); +array full(std::vector shape, array vals, StreamOrDevice s = {}); template -array full( - const std::vector& shape, - T val, - Dtype dtype, - StreamOrDevice s = {}) { - return full(shape, array(val, dtype), to_stream(s)); +array full(std::vector shape, T val, Dtype dtype, StreamOrDevice s = {}) { + return full(std::move(shape), array(val, dtype), to_stream(s)); } template -array full(const std::vector& shape, T val, StreamOrDevice s = {}) { - return full(shape, array(val), to_stream(s)); +array full(std::vector shape, T val, StreamOrDevice s = {}) { + return full(std::move(shape), array(val), to_stream(s)); } /** Fill an array of the given shape with zeros. */ diff --git a/mlx/primitives.h b/mlx/primitives.h index 78e759d62..9e36afbd1 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -420,12 +420,12 @@ class AsStrided : public UnaryPrimitive { public: explicit AsStrided( Stream stream, - const std::vector& shape, - const std::vector& strides, + std::vector shape, + std::vector strides, size_t offset) : UnaryPrimitive(stream), - shape_(shape), - strides_(strides), + shape_(std::move(shape)), + strides_(std::move(strides)), offset_(offset){}; void eval_cpu(const std::vector& inputs, array& out) override;