diff --git a/mlx/array.cpp b/mlx/array.cpp index 864c6973df..c2982d05c7 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -36,22 +36,11 @@ array::array(const std::complex& val, Dtype dtype /* = complex64 */) init(&cval); } -array::array( - const std::vector& shape, - Dtype dtype, - std::shared_ptr primitive, - const std::vector& inputs) - : array_desc_(std::make_shared( - shape, - dtype, - std::move(primitive), - inputs)) {} - array::array( std::vector shape, Dtype dtype, std::shared_ptr primitive, - std::vector&& inputs) + std::vector inputs) : array_desc_(std::make_shared( std::move(shape), dtype, @@ -92,10 +81,10 @@ array::array(std::initializer_list data, Dtype dtype) /* Build an array from a shared buffer */ array::array( allocator::Buffer data, - const std::vector& shape, + std::vector shape, Dtype dtype, deleter_t deleter) - : array_desc_(std::make_shared(shape, dtype)) { + : array_desc_(std::make_shared(std::move(shape), dtype)) { set_data(data, deleter); } @@ -181,31 +170,16 @@ void array::move_shared_buffer(array other) { move_shared_buffer(other, other.strides(), other.flags(), other.data_size()); } -array::ArrayDesc::ArrayDesc(const std::vector& shape, Dtype dtype) - : shape(shape), dtype(dtype) { - std::tie(size, strides) = cum_prod(shape); -} - -array::ArrayDesc::ArrayDesc( - const std::vector& shape, - Dtype dtype, - std::shared_ptr primitive, - const std::vector& inputs) - : shape(shape), - dtype(dtype), - primitive(std::move(primitive)), - inputs(inputs) { +array::ArrayDesc::ArrayDesc(std::vector shape, Dtype dtype) + : shape(std::move(shape)), dtype(dtype) { std::tie(size, strides) = cum_prod(this->shape); - for (auto& in : this->inputs) { - is_tracer |= in.is_tracer(); - } } array::ArrayDesc::ArrayDesc( - std::vector&& shape, + std::vector shape, Dtype dtype, std::shared_ptr primitive, - std::vector&& inputs) + std::vector inputs) : shape(std::move(shape)), dtype(dtype), primitive(std::move(primitive)), diff --git a/mlx/array.h b/mlx/array.h index 1f04d8ff83..c5a76d9686 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -32,7 +32,7 @@ class array { template array( It data, - const std::vector& shape, + std::vector shape, Dtype dtype = TypeToDtype::value_type>()); @@ -48,13 +48,13 @@ class array { template array( std::initializer_list data, - const std::vector& shape, + std::vector shape, Dtype dtype = TypeToDtype()); /* Build an array from a buffer */ array( allocator::Buffer data, - const std::vector& shape, + std::vector shape, Dtype dtype, deleter_t deleter = allocator::free); @@ -173,17 +173,11 @@ class array { * API may change. */ - array( - const std::vector& shape, - Dtype dtype, - std::shared_ptr primitive, - const std::vector& inputs); - array( std::vector shape, Dtype dtype, std::shared_ptr primitive, - std::vector&& inputs); + std::vector inputs); static std::vector make_arrays( const std::vector>& shapes, @@ -391,19 +385,13 @@ class array { // The arrays position in the output list uint32_t position{0}; - explicit ArrayDesc(const std::vector& shape, Dtype dtype); + explicit ArrayDesc(std::vector shape, Dtype dtype); explicit ArrayDesc( - const std::vector& shape, + std::vector shape, Dtype dtype, std::shared_ptr primitive, - const std::vector& inputs); - - explicit ArrayDesc( - std::vector&& shape, - Dtype dtype, - std::shared_ptr primitive, - std::vector&& inputs); + std::vector inputs); }; // The ArrayDesc contains the details of the materialized array including the @@ -422,9 +410,9 @@ array::array(T val, Dtype dtype /* = TypeToDtype() */) template array::array( It data, - const std::vector& shape, + std::vector shape, Dtype dtype /* = TypeToDtype::value_type>() */) : - array_desc_(std::make_shared(shape, dtype)) { + array_desc_(std::make_shared(std::move(shape), dtype)) { init(data); } @@ -441,9 +429,9 @@ array::array( template array::array( std::initializer_list data, - const std::vector& shape, + std::vector shape, Dtype dtype /* = TypeToDtype() */) - : array_desc_(std::make_shared(shape, dtype)) { + : array_desc_(std::make_shared(std::move(shape), dtype)) { if (data.size() != size()) { throw std::invalid_argument( "Data size and provided shape mismatch in array construction.");