diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp
index a1e99d7c7..6612a01a8 100644
--- a/mlx/backend/common/primitives.cpp
+++ b/mlx/backend/common/primitives.cpp
@@ -468,20 +468,78 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
+std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
+    const array& in,
+    const array& out) {
+  // Special case for empty arrays
+  if (in.size() == 0) {
+    return {false, out.strides()};
+  }
+
+  // Special case for scalars
+  if (in.ndim() == 0) {
+    std::vector<size_t> out_strides(out.ndim(), 0);
+    return {false, out_strides};
+  }
+
+  // Firstly let's collapse all the contiguous dimensions of the input
+  auto [shape, _strides] = collapse_contiguous_dims(in);
+  auto& strides = _strides[0];
+
+  // If shapes fit exactly in the contiguous dims then no copy is necessary so
+  // let's check.
+  std::vector<size_t> out_strides;
+  bool copy_necessary = false;
+  int j = 0;
+  for (int i = 0; i < out.ndim(); i++) {
+    int N = out.shape(i);
+    if (j < shape.size() && shape[j] % N == 0) {
+      shape[j] /= N;
+      out_strides.push_back(shape[j] * strides[j]);
+      j += (shape[j] == 1);
+    } else if (N == 1) {
+      // i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
+      out_strides.push_back(out_strides.back());
+    } else {
+      copy_necessary = true;
+      break;
+    }
+  }
+
+  return {copy_necessary, out_strides};
+}
+
+void Reshape::shared_buffer_reshape(
+    const array& in,
+    const std::vector<size_t>& out_strides,
+    array& out) {
+  auto flags = in.flags();
+  if (flags.contiguous && in.data_size() == in.size()) {
+    size_t f_stride = 1;
+    size_t b_stride = 1;
+    flags.col_contiguous = true;
+    flags.row_contiguous = true;
+    for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
+      flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
+      f_stride *= out.shape(i);
+      flags.row_contiguous &=
+          (out_strides[ri] == b_stride || out.shape(ri) == 1);
+      b_stride *= out.shape(ri);
+    }
+  }
+  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
+}
+
 void Reshape::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  if (in.flags().row_contiguous) {
-    // For row contiguous reshapes:
-    // - Shallow copy the buffer
-    // - If reshaping into a vector (all singleton dimensions except one) it
-    //   becomes col contiguous again.
-    auto flags = in.flags();
-    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
-    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
-    out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
-  } else {
+
+  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
+
+  if (copy_necessary) {
     copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
+  } else {
+    shared_buffer_reshape(in, out_strides, out);
   }
 }
 
diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h
index e170f5dfa..8023789dc 100644
--- a/mlx/backend/common/utils.h
+++ b/mlx/backend/common/utils.h
@@ -28,4 +28,70 @@ inline size_t elem_to_loc(int elem, const array& a) {
   return elem_to_loc(elem, a.shape(), a.strides());
 }
 
+// Collapse dims that are contiguous to possibly route to a better kernel
+// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
+// should return {{2, 4}, {{1, 2}}}.
+//
+// When multiple arrays are passed they should all have the same shape. The
+// collapsed axes are also the same so one shape is returned.
+inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
+collapse_contiguous_dims(
+    const std::vector<int>& shape,
+    const std::vector<std::vector<size_t>> strides) {
+  // Make a vector that has axes separated with -1. Collapse all axes between
+  // -1.
+  std::vector<int> to_collapse;
+  if (shape.size() > 0) {
+    to_collapse.push_back(0);
+    for (int i = 1; i < shape.size(); i++) {
+      bool contiguous = true;
+      for (const std::vector<size_t>& st : strides) {
+        if (st[i] * shape[i] != st[i - 1]) {
+          contiguous = false;
+        }
+        if (!contiguous) {
+          break;
+        }
+      }
+      if (!contiguous) {
+        to_collapse.push_back(-1);
+      }
+      to_collapse.push_back(i);
+    }
+    to_collapse.push_back(-1);
+  }
+
+  std::vector<int> out_shape;
+  std::vector<std::vector<size_t>> out_strides(strides.size());
+  for (int i = 0; i < to_collapse.size(); i++) {
+    int current_shape = shape[to_collapse[i]];
+    while (to_collapse[++i] != -1) {
+      current_shape *= shape[to_collapse[i]];
+    }
+    out_shape.push_back(current_shape);
+    for (int j = 0; j < strides.size(); j++) {
+      const std::vector<size_t>& st = strides[j];
+      out_strides[j].push_back(st[to_collapse[i - 1]]);
+    }
+  }
+
+  return std::make_tuple(out_shape, out_strides);
+}
+
+inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
+collapse_contiguous_dims(const std::vector<array>& xs) {
+  std::vector<std::vector<size_t>> strides;
+  for (auto& x : xs) {
+    strides.emplace_back(x.strides());
+  }
+  return collapse_contiguous_dims(xs[0].shape(), strides);
+}
+
+template <typename... Arrays>
+inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
+collapse_contiguous_dims(Arrays... xs) {
+  return collapse_contiguous_dims(
+      std::vector<array>{std::forward<Arrays>(xs)...});
+}
+
 } // namespace mlx::core
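The heart of the change is the factoring loop in Reshape::prepare_reshape: after collapsing the input to its widest contiguous dims, every requested output dimension must divide the current collapsed dim exactly for the buffer to be shared; otherwise a copy is unavoidable. The following standalone sketch is not part of the patch — plain vectors and the hypothetical name factor_strides stand in for the mlx types — but it shows the same loop in isolation:

    #include <cstddef>
    #include <iostream>
    #include <utility>
    #include <vector>

    // Try to factor `out_shape` into the collapsed dims (`shape`, `strides`).
    // Returns {copy_needed, out_strides}; mirrors the loop in prepare_reshape.
    std::pair<bool, std::vector<size_t>> factor_strides(
        std::vector<int> shape,
        const std::vector<size_t>& strides,
        const std::vector<int>& out_shape) {
      std::vector<size_t> out_strides;
      size_t j = 0;
      for (int N : out_shape) {
        if (j < shape.size() && shape[j] % N == 0) {
          shape[j] /= N;                              // peel N off the current dim
          out_strides.push_back(shape[j] * strides[j]);
          j += (shape[j] == 1);                       // dim exhausted, advance
        } else if (N == 1) {
          out_strides.push_back(out_strides.back());  // singleton: stride is free
        } else {
          return {true, {}};                          // doesn't factor: must copy
        }
      }
      return {false, out_strides};
    }

    int main() {
      // The transposed {2, 2, 2} example from the header comment collapses to
      // shape {2, 4} with strides {1, 2}; reshaping to {2, 2, 2} shares the buffer.
      auto [copy, st] = factor_strides({2, 4}, {1, 2}, {2, 2, 2});
      std::cout << copy;                          // 0 (no copy)
      for (size_t s : st) std::cout << ' ' << s;  // 1 4 2
      std::cout << '\n';
    }

Running it prints 0 1 4 2: no copy is needed, and the strides {1, 4, 2} address the transposed buffer directly.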
diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp
index 0879e623d..f59d0582a 100644
--- a/mlx/backend/metal/compiled.cpp
+++ b/mlx/backend/metal/compiled.cpp
@@ -3,6 +3,7 @@
 #include <sstream>
 
 #include "mlx/backend/common/compiled.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/backend/metal/compiled_preamble.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/utils.h"
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 067011260..83c20db6f 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -805,13 +805,13 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
 void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
-  if (in.flags().row_contiguous) {
-    auto flags = in.flags();
-    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
-    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
-    out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
-  } else {
+
+  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
+
+  if (copy_necessary) {
     copy_gpu(in, out, CopyType::General);
+  } else {
+    shared_buffer_reshape(in, out_strides, out);
   }
 }
 
diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h
index 363632a30..a8e4cfd44 100644
--- a/mlx/backend/metal/utils.h
+++ b/mlx/backend/metal/utils.h
@@ -96,72 +96,6 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
   return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
 }
 
-// Collapse dims that are contiguous to possibly route to a better kernel
-// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
-// should return {{2, 4}, {{1, 2}}}.
-//
-// When multiple arrays are passed they should all have the same shape. The
-// collapsed axes are also the same so one shape is returned.
-std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
-collapse_contiguous_dims(
-    const std::vector<int>& shape,
-    const std::vector<std::vector<size_t>> strides) {
-  // Make a vector that has axes separated with -1. Collapse all axes between
-  // -1.
-  std::vector<int> to_collapse;
-  if (shape.size() > 0) {
-    to_collapse.push_back(0);
-    for (int i = 1; i < shape.size(); i++) {
-      bool contiguous = true;
-      for (const std::vector<size_t>& st : strides) {
-        if (st[i] * shape[i] != st[i - 1]) {
-          contiguous = false;
-        }
-        if (!contiguous) {
-          break;
-        }
-      }
-      if (!contiguous) {
-        to_collapse.push_back(-1);
-      }
-      to_collapse.push_back(i);
-    }
-    to_collapse.push_back(-1);
-  }
-
-  std::vector<int> out_shape;
-  std::vector<std::vector<size_t>> out_strides(strides.size());
-  for (int i = 0; i < to_collapse.size(); i++) {
-    int current_shape = shape[to_collapse[i]];
-    while (to_collapse[++i] != -1) {
-      current_shape *= shape[to_collapse[i]];
-    }
-    out_shape.push_back(current_shape);
-    for (int j = 0; j < strides.size(); j++) {
-      const std::vector<size_t>& st = strides[j];
-      out_strides[j].push_back(st[to_collapse[i - 1]]);
-    }
-  }
-
-  return std::make_tuple(out_shape, out_strides);
-}
-
-std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
-collapse_contiguous_dims(const std::vector<array>& xs) {
-  std::vector<std::vector<size_t>> strides;
-  for (auto& x : xs) {
-    strides.emplace_back(x.strides());
-  }
-  return collapse_contiguous_dims(xs[0].shape(), strides);
-}
-
-template <typename... Arrays>
-std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
-collapse_contiguous_dims(Arrays... xs) {
-  return collapse_contiguous_dims(
-      std::vector<array>{std::forward<Arrays>(xs)...});
-}
-
 } // namespace
 
 } // namespace mlx::core
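collapse_contiguous_dims itself moves essentially unchanged from the Metal backend's anonymous namespace into mlx/backend/common/utils.h, marked inline since that header is included from several translation units; this is what lets the CPU path in prepare_reshape reuse it. Its merge rule is that adjacent axes i - 1 and i collapse whenever strides[i] * shape[i] == strides[i - 1]. A minimal sketch of the single-array case (plain vectors, hypothetical name collapse_dims; not part of the patch):

    #include <cstddef>
    #include <iostream>
    #include <utility>
    #include <vector>

    // Merge axis i into the running group when strides[i] * shape[i] equals
    // the stride of the axis before it; otherwise start a new collapsed axis.
    std::pair<std::vector<int>, std::vector<size_t>> collapse_dims(
        const std::vector<int>& shape,
        const std::vector<size_t>& strides) {
      std::vector<int> out_shape{shape[0]};
      std::vector<size_t> out_strides{strides[0]};
      for (size_t i = 1; i < shape.size(); i++) {
        if (strides[i] * static_cast<size_t>(shape[i]) == out_strides.back()) {
          out_shape.back() *= shape[i];     // contiguous: fold axis i in
          out_strides.back() = strides[i];  // innermost stride represents the group
        } else {
          out_shape.push_back(shape[i]);
          out_strides.push_back(strides[i]);
        }
      }
      return {out_shape, out_strides};
    }

    int main() {
      // A row-contiguous {2, 2, 2} array transposed with permutation {2, 0, 1}
      // has strides {1, 4, 2}.
      auto [sh, st] = collapse_dims({2, 2, 2}, {1, 4, 2});
      for (int d : sh) std::cout << d << ' ';     // 2 4
      for (size_t s : st) std::cout << s << ' ';  // 1 2
      std::cout << '\n';
    }

For the transposed example from the header comment this prints shape {2, 4} and strides {1, 2}: only the last two axes merge, matching the documented result.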
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index ad3efab71..c30a6468a 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -1518,9 +1518,11 @@ array min(
 array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
   int size = a.size();
-  auto result = argmin(reshape(a, {size}, s), 0, false, s);
+  auto result = argmin(reshape(a, {size}, s), 0, true, s);
   if (keepdims) {
     result = reshape(result, std::vector<int>(a.shape().size(), 1), s);
+  } else {
+    result = squeeze(result, s);
   }
   return result;
 }
@@ -1549,9 +1551,11 @@ array argmin(
 array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
   int size = a.size();
-  auto result = argmax(reshape(a, {size}, s), 0, false, s);
+  auto result = argmax(reshape(a, {size}, s), 0, true, s);
   if (keepdims) {
     result = reshape(result, std::vector<int>(a.shape().size(), 1), s);
+  } else {
+    result = squeeze(result, s);
   }
   return result;
 }
diff --git a/mlx/primitives.h b/mlx/primitives.h
index d428fc3ab..aea99eda9 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1359,6 +1359,14 @@ class Reshape : public UnaryPrimitive {
   std::vector<int> shape_;
 
   void eval(const std::vector<array>& inputs, array& out);
+
+  std::pair<bool, std::vector<size_t>> prepare_reshape(
+      const array& in,
+      const array& out);
+  void shared_buffer_reshape(
+      const array& in,
+      const std::vector<size_t>& out_strides,
+      array& out);
 };
 
 class Reduce : public UnaryPrimitive {
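The full-array argmin and argmax overloads now request keepdims = true from the axis overload and squeeze the leftover axis away when the caller passed keepdims = false, rather than reshaping a rank-0 reduction result. The observable shapes are unchanged, as this usage sketch shows (it assumes only the public mlx/mlx.h header; not part of the patch):

    #include <iostream>
    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      array a = ones({2, 3});
      auto r1 = argmax(a, /* keepdims */ true);   // reshape path: shape {1, 1}
      auto r2 = argmax(a, /* keepdims */ false);  // squeeze path: shape {}
      std::cout << r1.ndim() << ' ' << r2.ndim() << '\n';  // prints: 2 0
    }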
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 953fa7839..c4305cf92 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -56,6 +56,50 @@ TEST_CASE("test reshape") {
   CHECK_THROWS_AS(reshape(x, {1}), std::invalid_argument);
   y = reshape(x, {1, 5, 0});
   CHECK_EQ(y.shape(), std::vector<int>{1, 5, 0});
+
+  // Check that reshaping a transposed array doesn't result in a copy
+  x = reshape(arange(64), {2, 4, 8});
+  x.eval();
+  CHECK_EQ(x.strides()[0], 32);
+  CHECK_EQ(x.strides()[1], 8);
+  CHECK_EQ(x.strides()[2], 1);
+  y = reshape(transpose(x, {0, 2, 1}), {2, 4, 2, 4});
+  y.eval();
+  CHECK_EQ(y.strides()[0], 32);
+  CHECK_EQ(y.strides()[1], 2);
+  CHECK_EQ(y.strides()[2], 1);
+  CHECK_EQ(y.strides()[3], 8);
+  CHECK_EQ(x.data<float>(), y.data<float>());
+
+  // Split transposed (2, 8, 4) -> (2, 8, 2, 2)
+  y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 2});
+  y.eval();
+  CHECK_EQ(y.strides()[0], 32);
+  CHECK_EQ(y.strides()[1], 1);
+  CHECK_EQ(y.strides()[2], 16);
+  CHECK_EQ(y.strides()[3], 8);
+  CHECK_EQ(x.data<float>(), y.data<float>());
+
+  // Split transposed (2, 8, 4) -> (2, 8, 2, 1, 2)
+  y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 1, 2});
+  y.eval();
+  CHECK_EQ(y.strides()[0], 32);
+  CHECK_EQ(y.strides()[1], 1);
+  CHECK_EQ(y.strides()[2], 16);
+  // y.strides()[3] can be anything since y.shape()[3] == 1
+  CHECK_EQ(y.strides()[4], 8);
+  CHECK_EQ(x.data<float>(), y.data<float>());
+
+  // Split transposed (2, 8, 4) -> (2, 8, 2, 1, 2, 1)
+  y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 1, 2, 1});
+  y.eval();
+  CHECK_EQ(y.strides()[0], 32);
+  CHECK_EQ(y.strides()[1], 1);
+  CHECK_EQ(y.strides()[2], 16);
+  // y.strides()[3] can be anything since y.shape()[3] == 1
+  CHECK_EQ(y.strides()[4], 8);
+  // y.strides()[5] can be anything since y.shape()[5] == 1
+  CHECK_EQ(x.data<float>(), y.data<float>());
 }
 
 TEST_CASE("test flatten") {
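Working the first test through prepare_reshape by hand: transpose(x, {0, 2, 1}) has shape {2, 8, 4} and strides {32, 1, 8}, and no pair of adjacent axes satisfies the merge rule (1 * 8 = 8 != 32 and 8 * 4 = 32 != 1), so collapsing leaves all three dims. Factoring {2, 4, 2, 4} into them: the leading 2 exhausts the first dim (stride 1 * 32 = 32), the 4 and 2 split the middle 8 (strides 2 * 1 = 2 and 1 * 1 = 1), and the trailing 4 exhausts the last dim (stride 1 * 8 = 8), giving {32, 2, 1, 8} with no copy, exactly as the test asserts. The later cases exercise the singleton branch, which is why strides at positions with shape 1 are deliberately left unchecked.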