From 73a8c090e02cb199e05db42897411cb5225ae7bf Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 20 Mar 2024 23:54:30 +0900 Subject: [PATCH] Pass shape and inputs by value in array's constructor (#853) Since the shape and inputs are always saved as copy in ArrayDesc, we can unify array's constructors to just take the arguments by value. There are 2 cases: 1. When shape is a lvalue, it will be copied into array's constructor and then moved into ArrayDesc's member. So only 1 copy happens. 2. When shape is a rvalue, it will be moved into array's constructor and then moved into ArrayDesc's member. So no copy happens. So having 1 constructor that takes by value is equivalent to having 2 constructors that const reference and rvalue separately. --- mlx/array.cpp | 40 +++++++--------------------------------- mlx/array.h | 34 +++++++++++----------------------- 2 files changed, 18 insertions(+), 56 deletions(-) diff --git a/mlx/array.cpp b/mlx/array.cpp index 864c6973df..c2982d05c7 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -36,22 +36,11 @@ array::array(const std::complex& val, Dtype dtype /* = complex64 */) init(&cval); } -array::array( - const std::vector& shape, - Dtype dtype, - std::shared_ptr primitive, - const std::vector& inputs) - : array_desc_(std::make_shared( - shape, - dtype, - std::move(primitive), - inputs)) {} - array::array( std::vector shape, Dtype dtype, std::shared_ptr primitive, - std::vector&& inputs) + std::vector inputs) : array_desc_(std::make_shared( std::move(shape), dtype, @@ -92,10 +81,10 @@ array::array(std::initializer_list data, Dtype dtype) /* Build an array from a shared buffer */ array::array( allocator::Buffer data, - const std::vector& shape, + std::vector shape, Dtype dtype, deleter_t deleter) - : array_desc_(std::make_shared(shape, dtype)) { + : array_desc_(std::make_shared(std::move(shape), dtype)) { set_data(data, deleter); } @@ -181,31 +170,16 @@ void array::move_shared_buffer(array other) { move_shared_buffer(other, other.strides(), other.flags(), other.data_size()); } -array::ArrayDesc::ArrayDesc(const std::vector& shape, Dtype dtype) - : shape(shape), dtype(dtype) { - std::tie(size, strides) = cum_prod(shape); -} - -array::ArrayDesc::ArrayDesc( - const std::vector& shape, - Dtype dtype, - std::shared_ptr primitive, - const std::vector& inputs) - : shape(shape), - dtype(dtype), - primitive(std::move(primitive)), - inputs(inputs) { +array::ArrayDesc::ArrayDesc(std::vector shape, Dtype dtype) + : shape(std::move(shape)), dtype(dtype) { std::tie(size, strides) = cum_prod(this->shape); - for (auto& in : this->inputs) { - is_tracer |= in.is_tracer(); - } } array::ArrayDesc::ArrayDesc( - std::vector&& shape, + std::vector shape, Dtype dtype, std::shared_ptr primitive, - std::vector&& inputs) + std::vector inputs) : shape(std::move(shape)), dtype(dtype), primitive(std::move(primitive)), diff --git a/mlx/array.h b/mlx/array.h index 1f04d8ff83..c5a76d9686 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -32,7 +32,7 @@ class array { template array( It data, - const std::vector& shape, + std::vector shape, Dtype dtype = TypeToDtype::value_type>()); @@ -48,13 +48,13 @@ class array { template array( std::initializer_list data, - const std::vector& shape, + std::vector shape, Dtype dtype = TypeToDtype()); /* Build an array from a buffer */ array( allocator::Buffer data, - const std::vector& shape, + std::vector shape, Dtype dtype, deleter_t deleter = allocator::free); @@ -173,17 +173,11 @@ class array { * API may change. */ - array( - const std::vector& shape, - Dtype dtype, - std::shared_ptr primitive, - const std::vector& inputs); - array( std::vector shape, Dtype dtype, std::shared_ptr primitive, - std::vector&& inputs); + std::vector inputs); static std::vector make_arrays( const std::vector>& shapes, @@ -391,19 +385,13 @@ class array { // The arrays position in the output list uint32_t position{0}; - explicit ArrayDesc(const std::vector& shape, Dtype dtype); + explicit ArrayDesc(std::vector shape, Dtype dtype); explicit ArrayDesc( - const std::vector& shape, + std::vector shape, Dtype dtype, std::shared_ptr primitive, - const std::vector& inputs); - - explicit ArrayDesc( - std::vector&& shape, - Dtype dtype, - std::shared_ptr primitive, - std::vector&& inputs); + std::vector inputs); }; // The ArrayDesc contains the details of the materialized array including the @@ -422,9 +410,9 @@ array::array(T val, Dtype dtype /* = TypeToDtype() */) template array::array( It data, - const std::vector& shape, + std::vector shape, Dtype dtype /* = TypeToDtype::value_type>() */) : - array_desc_(std::make_shared(shape, dtype)) { + array_desc_(std::make_shared(std::move(shape), dtype)) { init(data); } @@ -441,9 +429,9 @@ array::array( template array::array( std::initializer_list data, - const std::vector& shape, + std::vector shape, Dtype dtype /* = TypeToDtype() */) - : array_desc_(std::make_shared(shape, dtype)) { + : array_desc_(std::make_shared(std::move(shape), dtype)) { if (data.size() != size()) { throw std::invalid_argument( "Data size and provided shape mismatch in array construction.");