diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp
index 9893b51cb..00338ef88 100644
--- a/mlx/backend/common/primitives.cpp
+++ b/mlx/backend/common/primitives.cpp
@@ -159,6 +159,17 @@ void Conjugate::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
+void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  if (in.flags().row_contiguous ||
+      (allow_col_major_ && in.flags().col_contiguous)) {
+    out.copy_shared_buffer(in);
+  } else {
+    copy(in, out, CopyType::General);
+  }
+}
+
 void Cos::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index f788435a2..9d8d3f942 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -170,6 +170,17 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
   concatenate_gpu(inputs, out, axis_, stream());
 }
 
+void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  auto& in = inputs[0];
+  if (in.flags().row_contiguous ||
+      (allow_col_major_ && in.flags().col_contiguous)) {
+    move_or_copy(in, out);
+  } else {
+    copy_gpu(in, out, CopyType::General);
+  }
+}
+
 void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
   eval(inputs, out);
 }
diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp
index 9afeaec8b..b2a83b997 100644
--- a/mlx/backend/no_cpu/primitives.cpp
+++ b/mlx/backend/no_cpu/primitives.cpp
@@ -39,6 +39,7 @@ NO_CPU(Ceil)
 NO_CPU(Cholesky)
 NO_CPU(Concatenate)
 NO_CPU(Conjugate)
+NO_CPU(Contiguous)
 NO_CPU(Convolution)
 NO_CPU(Copy)
 NO_CPU(Cos)
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index aaee51d83..98c89037e 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -40,6 +40,7 @@ NO_GPU(Ceil)
 NO_GPU_MULTI(Compiled)
 NO_GPU(Concatenate)
 NO_GPU(Conjugate)
+NO_GPU(Contiguous)
 NO_GPU(Convolution)
 NO_GPU(Copy)
 NO_GPU(Cos)
diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp
index f2e6f85bd..d5ac38518 100644
--- a/mlx/io/load.cpp
+++ b/mlx/io/load.cpp
@@ -107,22 +107,13 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
   ////////////////////////////////////////////////////////
   // Check array
 
+  a = contiguous(a, true);
   a.eval();
 
   if (a.nbytes() == 0) {
     throw std::invalid_argument("[save] cannot serialize an empty array");
   }
-  if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
-    a = reshape(flatten(a), a.shape());
-    a.eval();
-  }
-
-  // Check once more in-case the above ops change
-  if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
-    throw std::invalid_argument(
-        "[save] can only serialize row or col contiguous arrays");
-  }
 
   ////////////////////////////////////////////////////////
   // Check file
   if (!out_stream->good() || !out_stream->is_open()) {
diff --git a/mlx/io/safetensors.cpp b/mlx/io/safetensors.cpp
index 5c4854186..0e5d3f5a1 100644
--- a/mlx/io/safetensors.cpp
+++ b/mlx/io/safetensors.cpp
@@ -179,8 +179,9 @@ void save_safetensors(
   {
     std::vector<array> to_eval;
     to_eval.reserve(a.size());
-    for (auto& [_, arr] : a) {
-      to_eval.push_back(arr);
+    for (auto& p : a) {
+      p.second = contiguous(p.second);
+      to_eval.push_back(p.second);
     }
     eval(std::move(to_eval));
   }
@@ -192,19 +193,6 @@ void save_safetensors(
           "[save_safetensors] cannot serialize an empty array key: " + key);
     }
 
-    // Try to make it row contiguous
-    if (!arr.flags().row_contiguous) {
-      arr = reshape(flatten(arr), arr.shape());
-      arr.eval();
-    }
-
-    // Has to be row-major now but, check one more time in case
-    // any of the above change in the future
-    if (!arr.flags().row_contiguous) {
-      throw std::invalid_argument(
-          "[save_safetensors] can only serialize row-major arrays");
-    }
-
     json child;
     child["dtype"] = dtype_to_safetensor_str(arr.dtype());
     child["shape"] = arr.shape();
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 4aef90639..4193b08e0 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -4700,4 +4700,15 @@ array imag(const array& a, StreamOrDevice s /* = {} */) {
   return array(a.shape(), float32, std::make_shared<Imag>(to_stream(s)), {a});
 }
 
+array contiguous(
+    const array& a,
+    bool allow_col_major /* = false */,
+    StreamOrDevice s /* = {} */) {
+  return array(
+      a.shape(),
+      a.dtype(),
+      std::make_shared<Contiguous>(to_stream(s), allow_col_major),
+      {a});
+}
+
 } // namespace mlx::core
diff --git a/mlx/ops.h b/mlx/ops.h
index c2ea9438c..fdceeed0d 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1486,6 +1486,12 @@ array real(const array& a, StreamOrDevice s = {});
 /* The imaginary part of a complex array. */
 array imag(const array& a, StreamOrDevice s = {});
 
+/* Ensure the array's underlying memory is contiguous. */
+array contiguous(
+    const array& a,
+    bool allow_col_major = false,
+    StreamOrDevice s = {});
+
 /** @} */
 
 } // namespace mlx::core
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 743b695dd..9d9ecd588 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -889,6 +889,32 @@ std::pair<std::vector<array>, std::vector<int>> Conjugate::vmap(
   return {{conjugate(inputs[0], stream())}, axes};
 }
 
+std::vector<array> Contiguous::vjp(
+    const std::vector<array>&,
+    const std::vector<array>& cotangents,
+    const std::vector<int>&,
+    const std::vector<array>&) {
+  return {cotangents};
+}
+
+std::vector<array> Contiguous::jvp(
+    const std::vector<array>&,
+    const std::vector<array>& tangents,
+    const std::vector<int>&) {
+  return {tangents};
+}
+
+std::pair<std::vector<array>, std::vector<int>> Contiguous::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  return {{contiguous(inputs[0], allow_col_major_, stream())}, axes};
+}
+
+bool Contiguous::is_equivalent(const Primitive& other) const {
+  const Contiguous& c_other = static_cast<const Contiguous&>(other);
+  return allow_col_major_ == c_other.allow_col_major_;
+}
+
 array conv_weight_backward_patches(
     const array& in,
     const array& wt,
diff --git a/mlx/primitives.h b/mlx/primitives.h
index f2b5bab7c..13022db24 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -639,6 +639,25 @@ class Conjugate : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class Contiguous : public UnaryPrimitive {
+ public:
+  explicit Contiguous(Stream stream, bool allow_col_major)
+      : UnaryPrimitive(stream), allow_col_major_(allow_col_major) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(Contiguous)
+  DEFINE_INPUT_OUTPUT_SHAPE()
+
+  bool is_equivalent(const Primitive& other) const override;
+
+ private:
+  bool allow_col_major_;
+};
+
 class Convolution : public UnaryPrimitive {
  public:
  explicit Convolution(
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 6bae16fad..4b32d5794 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -3747,3 +3747,17 @@ TEST_CASE("test roll") {
   CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5}))
             .item<bool>());
 }
+
+TEST_CASE("test contiguous") {
+  auto x = array({1, 2, 3});
+  x = contiguous(broadcast_to(x, {2, 2, 3}));
+  eval(x);
+  CHECK(x.flags().row_contiguous);
+  CHECK_EQ(x.strides(), decltype(x.strides()){6, 3, 1});
+
+  x = array({1, 2, 1, 2}, {2, 2});
+  x = contiguous(transpose(x), true);
+  eval(x);
+  CHECK(x.flags().col_contiguous);
+  CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});
+}
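
Usage note (commentary, not part of the patch): `contiguous` forces a row-contiguous copy of its input unless the input already qualifies, and with `allow_col_major = true` a column-major input is accepted and shared without copying. The sketch below is illustrative only; it assumes the public MLX C++ header `mlx/mlx.h` and the existing `transpose`/`eval` APIs exercised by the test above.

    // Minimal sketch of the new op's two modes.
    #include <cassert>
    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      // Transposing a row-major 2x2 array yields a column-major view.
      auto t = transpose(array({1, 2, 3, 4}, {2, 2}));

      // Default: materialize a row-contiguous copy (the CopyType::General path).
      auto r = contiguous(t);
      eval(r);
      assert(r.flags().row_contiguous);

      // allow_col_major = true: the column-major input qualifies as-is, so
      // eval_cpu/eval_gpu take the zero-copy shared-buffer path.
      auto c = contiguous(t, true);
      eval(c);
      assert(c.flags().col_contiguous);
      return 0;
    }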