From ee59d50293973c526f16dc93c0d9da7771f51cac Mon Sep 17 00:00:00 2001
From: Awni Hannun <awni@apple.com>
Date: Fri, 6 Dec 2024 12:09:08 -0800
Subject: [PATCH] try dynamic reshape

---
 mlx/ops.cpp              | 63 ++++++++++++++++++++++++++++++++++++++++
 mlx/ops.h                |  6 ++++
 mlx/primitives.cpp       | 40 +++++++++++++++++++++++++
 mlx/primitives.h         |  7 +++++
 python/src/ops.cpp       | 23 +++++++++++++++
 python/tests/test_ops.py |  5 ++++
 tests/ops_tests.cpp      | 20 +++++++++++++
 7 files changed, 164 insertions(+)

diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index af772ce611..0d28266b6c 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -403,6 +403,69 @@ array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
   return array(std::move(shape), a.dtype(), std::move(p), {a});
 }
 
+// Variant of string and int for the expressions
+array dynamic_reshape(
+    const array& a,
+    std::vector<std::variant<int, std::string>> expressions,
+    StreamOrDevice s /* = {} */) {
+  // Reshape to scalar is not dynamic
+  if (expressions.empty()) {
+    return reshape(a, {}, s);
+  }
+
+  // Validate expressions:
+  // - At most one item in expressions is -1
+  // - Any string expression should have a letter
+  // - At most a.ndim() unique letters
+  // - Only valid characters in string (alphabet, integer, *, /)
+  bool infer_dim = false;
+  std::unordered_set<char> dims;
+  for (auto& e : expressions) {
+    if (auto pv = std::get_if<int>(&e); pv) {
+      if (*pv == -1) {
+        if (infer_dim) {
+          throw std::invalid_argument(
+              "[dynamic_reshape] Cannot infer more than one dimension.");
+        }
+        infer_dim = true;
+      }
+    } else {
+      auto& s = std::get<std::string>(e);
+      bool has_alpha = false;
+      for (auto c : s) {
+        if (isalpha(static_cast<unsigned char>(c))) {
+          has_alpha = true;
+          dims.insert(c);
+        } else if (!isdigit(static_cast<unsigned char>(c)) && c != '*' && c != '/') {
+          std::ostringstream msg;
+          msg << "[dynamic_reshape] Invalid character in string expression \""
+              << s << "\".";
+          throw std::invalid_argument(msg.str());
+        }
+      }
+      if (!has_alpha) {
+        std::ostringstream msg;
+        msg << "[dynamic_reshape] String expression must contain at least "
+            << "one alphabetic character but got: \"" << s << "\".";
+        throw std::invalid_argument(msg.str());
+      }
+    }
+  }
+  if (dims.size() > a.ndim()) {
+    std::ostringstream msg;
+    msg << "[dynamic_reshape] Expressions contain " << dims.size()
+        << " abstract dimensions for array with only " << a.ndim()
+        << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+  auto output_shape = Shape{}; // Reshape::shape_from_expression(a, expression);
+  return array(
+      std::move(output_shape),
+      a.dtype(),
+      std::make_shared<Reshape>(to_stream(s), std::move(expressions)),
+      {a});
+}
+
 array flatten(
     const array& a,
     int start_axis,
diff --git a/mlx/ops.h b/mlx/ops.h
index 7e24b5820d..e09386c11c 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -117,6 +117,12 @@ array triu(array x, int k = 0, StreamOrDevice s = {});
 /** Reshape an array to the given shape. */
 array reshape(const array& a, Shape shape, StreamOrDevice s = {});
 
+/** Dynamically reshape an array based on the given expressions. */
+array dynamic_reshape(
+    const array& a,
+    std::vector<std::variant<int, std::string>> expressions,
+    StreamOrDevice s = {});
+
 /** Flatten the dimensions in the range `[start_axis, end_axis]` . */
 array flatten(
     const array& a,
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index ab1c1f03be..6f58081523 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -2858,9 +2858,49 @@ std::vector<array> Reshape::jvp(
 
 bool Reshape::is_equivalent(const Primitive& other) const {
   const Reshape& r_other = static_cast<const Reshape&>(other);
+  if (!expression_.empty()) {
+    return expression_ == r_other.expression_;
+  }
   return shape_ == r_other.shape_;
 }
 
+std::vector<Shape> Reshape::output_shapes(const std::vector<array>& inputs) {
+  // Only allowed to dynamically reshape when the shape is {}
+  if (expression_.empty() && !shape_.empty()) {
+    throw std::invalid_argument(
+        "[Reshape::output_shapes] Unable to infer output shape.");
+  }
+
+  auto& in = inputs[0];
+  Shape output_shape(expression_.size());
+  int dim_to_infer = -1;
+  for (int i = 0, j = 0; i < expression_.size(); ++i) {
+    auto& e = expression_[i];
+    if (auto pv = std::get_if<int>(&e); pv) {
+      if (*pv == -1) {
+        dim_to_infer = i;
+        continue;
+      } else {
+        output_shape[i] = *pv;
+      }
+    } else {
+      auto& s = std::get<std::string>(e);
+      output_shape[i] = in.shape()[j++];
+    }
+  }
+
+  if (dim_to_infer >= 0) {
+    uint64_t output_size = 1;
+    for (int i = 0; i < output_shape.size(); ++i) {
+      if (i != dim_to_infer) {
+        output_size *= output_shape[i];
+      }
+    }
+    output_shape[dim_to_infer] = in.size() / output_size;
+  }
+  return {std::move(output_shape)};
+}
+
 std::vector<array> Reduce::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
diff --git a/mlx/primitives.h b/mlx/primitives.h
index a166f164c3..e9b0a743a3 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1609,6 +1609,11 @@ class Reshape : public UnaryPrimitive {
   explicit Reshape(Stream stream, const Shape& shape)
       : UnaryPrimitive(stream), shape_(shape) {}
 
+  explicit Reshape(
+      Stream stream,
+      std::vector<std::variant<int, std::string>> expression)
+      : UnaryPrimitive(stream), expression_(std::move(expression)) {}
+
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
 
@@ -1616,9 +1621,11 @@ class Reshape : public UnaryPrimitive {
   DEFINE_GRADS()
   DEFINE_PRINT(Reshape)
   bool is_equivalent(const Primitive& other) const override;
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
 
  private:
   Shape shape_;
+  std::vector<std::variant<int, std::string>> expression_;
 
   void eval(const std::vector<array>& inputs, array& out);
 
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index eb69b2659e..1386bc4868 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -4880,4 +4880,27 @@ void init_ops(nb::module_& m) {
       Returns:
           array: The imaginary part of ``a``.
       )pbdoc");
+  m.def(
+      "dynamic_reshape",
+      &dynamic_reshape,
+      nb::arg(),
+      "expression"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def dynamic_reshape(a: array, /, expression: Sequence[Union[int, str]], *, stream: "
+          "Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Dynamically reshape an array based on the given expression.
+
+        Args:
+            a (array): Input array.
+            expression (tuple(int or str)): The expression which determines the
+              output shape.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            array: The reshaped array.
+      )pbdoc");
 }
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 54a3cf8c43..dcf81d1975 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -2713,6 +2713,11 @@ class TestOps(mlx_tests.MLXTestCase):
         self.assertEqual(mx.imag(z).dtype, mx.float32)
         self.assertTrue(mx.array_equal(mx.imag(z), y))
 
+    def test_dynamic_reshape(self):
+        a = mx.array(1)[None, None]
+        a = mx.dynamic_reshape(a, ())
+        self.assertEqual(a.shape, ())
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 545f5e24c2..c2efd0ce03 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -3769,3 +3769,23 @@ TEST_CASE("test contiguous") {
   CHECK(x.flags().col_contiguous);
   CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});
 }
+
+TEST_CASE("test dynamic reshape") {
+  auto x = array({1}, {1, 1, 1});
+  CHECK_EQ(dynamic_reshape(x, {}).shape(), Shape{});
+
+  // Bad character
+  CHECK_THROWS(dynamic_reshape(x, {"&", 1, 1}));
+
+  // No dim in string
+  CHECK_THROWS(dynamic_reshape(x, {"1", 1, 1}));
+
+  // Too many dims
+  CHECK_THROWS(dynamic_reshape(x, {"abcd", 1, 1}));
+
+  // Too many dims
+  CHECK_THROWS(dynamic_reshape(x, {"a", "b", "c", "d"}));
+
+  // Too many inferred dims
+  CHECK_THROWS(dynamic_reshape(x, {"a", -1, -1}));
+}