diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 2dfe33a05..c1e9c8e0f 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -158,50 +158,64 @@ array linspace( to_stream(s)); } -array astype(const array& a, Dtype dtype, StreamOrDevice s /* = {} */) { +array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) { if (dtype == a.dtype()) { - return a; + return std::move(a); } + auto copied_shape = a.shape(); // |a| will be moved return array( - a.shape(), dtype, std::make_shared(to_stream(s), dtype), {a}); + std::move(copied_shape), + dtype, + std::make_shared(to_stream(s), dtype), + {std::move(a)}); } array as_strided( - const array& a, + array a, std::vector shape, std::vector strides, size_t offset, StreamOrDevice s /* = {} */) { - // Force the input array to be contiguous - auto x = reshape(a, {-1}, s); + auto copied_shape = shape; // |shape| will be moved + auto dtype = a.dtype(); // |a| will be moved return array( - shape, - a.dtype(), - std::make_shared(to_stream(s), shape, strides, offset), - {x}); + std::move(copied_shape), + dtype, + std::make_shared( + to_stream(s), std::move(shape), std::move(strides), offset), + // Force the input array to be contiguous. + {reshape(std::move(a), {-1}, s)}); } -array copy(const array& a, StreamOrDevice s /* = {} */) { - return array(a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); +array copy(array a, StreamOrDevice s /* = {} */) { + auto copied_shape = a.shape(); // |a| will be moved + auto dtype = a.dtype(); + return array( + std::move(copied_shape), + dtype, + std::make_shared(to_stream(s)), + {std::move(a)}); } array full( - const std::vector& shape, - const array& vals, + std::vector shape, + array vals, Dtype dtype, StreamOrDevice s /* = {} */) { - if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) { + if (std::any_of(shape.begin(), shape.end(), [](int i) { return i < 0; })) { throw std::invalid_argument("[full] Negative dimensions not allowed."); } - auto in = broadcast_to(astype(vals, dtype, s), shape, s); - return array(shape, dtype, std::make_shared(to_stream(s)), {in}); + auto copied_shape = shape; // |shape| will be moved + return array( + std::move(copied_shape), + dtype, + std::make_shared(to_stream(s)), + {broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)}); } -array full( - const std::vector& shape, - const array& vals, - StreamOrDevice s /* = {} */) { - return full(shape, vals, vals.dtype(), to_stream(s)); +array full(std::vector shape, array vals, StreamOrDevice s /* = {} */) { + auto dtype = vals.dtype(); // |vals| will be moved + return full(std::move(shape), std::move(vals), dtype, to_stream(s)); } array zeros( diff --git a/mlx/ops.h b/mlx/ops.h index ff9963ed7..69059ed75 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -41,40 +41,33 @@ array linspace( StreamOrDevice s = {}); /** Convert an array to the given data type. */ -array astype(const array& a, Dtype dtype, StreamOrDevice s = {}); +array astype(array a, Dtype dtype, StreamOrDevice s = {}); /** Create a view of an array with the given shape and strides. */ array as_strided( - const array& a, + array a, std::vector shape, std::vector strides, size_t offset, StreamOrDevice s = {}); /** Copy another array. */ -array copy(const array& a, StreamOrDevice s = {}); +array copy(array a, StreamOrDevice s = {}); /** Fill an array of the given shape with the given value(s). */ array full( - const std::vector& shape, - const array& vals, + std::vector shape, + array vals, Dtype dtype, StreamOrDevice s = {}); -array full( - const std::vector& shape, - const array& vals, - StreamOrDevice s = {}); +array full(std::vector shape, array vals, StreamOrDevice s = {}); template -array full( - const std::vector& shape, - T val, - Dtype dtype, - StreamOrDevice s = {}) { - return full(shape, array(val, dtype), to_stream(s)); +array full(std::vector shape, T val, Dtype dtype, StreamOrDevice s = {}) { + return full(std::move(shape), array(val, dtype), to_stream(s)); } template -array full(const std::vector& shape, T val, StreamOrDevice s = {}) { - return full(shape, array(val), to_stream(s)); +array full(std::vector shape, T val, StreamOrDevice s = {}) { + return full(std::move(shape), array(val), to_stream(s)); } /** Fill an array of the given shape with zeros. */ diff --git a/mlx/primitives.h b/mlx/primitives.h index 78e759d62..9e36afbd1 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -420,12 +420,12 @@ class AsStrided : public UnaryPrimitive { public: explicit AsStrided( Stream stream, - const std::vector& shape, - const std::vector& strides, + std::vector shape, + std::vector strides, size_t offset) : UnaryPrimitive(stream), - shape_(shape), - strides_(strides), + shape_(std::move(shape)), + strides_(std::move(strides)), offset_(offset){}; void eval_cpu(const std::vector& inputs, array& out) override;