From 2b9c24c51799f182c6f7f21e5a12e9e17a5e3544 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 6 Dec 2024 16:50:08 -0800 Subject: [PATCH] works --- mlx/ops.cpp | 21 +++++++---- mlx/primitives.cpp | 85 ++++++++++++++++++++++++++++++++++----------- mlx/primitives.h | 15 ++++++-- tests/ops_tests.cpp | 39 +++++++++++++++++++++ 4 files changed, 130 insertions(+), 30 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 0d28266b6..b8eb799cb 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -419,7 +419,7 @@ array dynamic_reshape( // - At most a.ndim() unique letters // - Only valid characters in string (alphabet, integer, *, /) bool infer_dim = false; - std::unordered_set dims; + std::unordered_map char_to_dim; for (auto& e : expressions) { if (auto pv = std::get_if(&e); pv) { if (*pv == -1) { @@ -435,7 +435,7 @@ array dynamic_reshape( for (auto c : s) { if (isalpha(c)) { has_alpha = true; - dims.insert(c); + char_to_dim.insert({c, char_to_dim.size()}); } else if (!isdigit(c) && c != '*' && c != '/') { std::ostringstream msg; msg << "[dynamic_reshape] Invalid character in string expression \"" @@ -449,20 +449,29 @@ array dynamic_reshape( << "one alphabetic character but got: \"" << s << "\"."; throw std::invalid_argument(msg.str()); } + if (!isdigit(s[0]) && !isalpha(s[0]) && !isdigit(s.back()) && + !isalpha(s.back())) { + std::ostringstream msg; + msg << "[dynamic_reshape] String expression must start and end with " + << "integer or letter but got: \"" << s << "\"."; + throw std::invalid_argument(msg.str()); + } } } - if (dims.size() >= a.ndim()) { + if (char_to_dim.size() > a.ndim()) { std::ostringstream msg; - msg << "[dynamic_reshape] Expressions contain " << dims.size() + msg << "[dynamic_reshape] Expressions contain " << char_to_dim.size() << " abstract dimensions for array with only " << a.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } - auto output_shape = Shape{}; // Reshape::shape_from_expression(a, expression); + auto output_shape = + Reshape::shape_from_expressions(expressions, char_to_dim, a); return array( std::move(output_shape), a.dtype(), - std::make_shared(to_stream(s), std::move(expressions)), + std::make_shared( + to_stream(s), std::move(expressions), std::move(char_to_dim)), {a}); } diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 6f5808152..ac455d120 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -2858,24 +2858,21 @@ std::vector Reshape::jvp( bool Reshape::is_equivalent(const Primitive& other) const { const Reshape& r_other = static_cast(other); - if (!expression_.empty()) { - return expression_ == r_other.expression_; + if (!expressions_.empty()) { + return expressions_ == r_other.expressions_; } return shape_ == r_other.shape_; } -std::vector Reshape::output_shapes(const std::vector& inputs) { - // Only allowed to dynamically reshape when the shape is {} - if (expression_.empty() && !shape_.empty()) { - throw std::invalid_argument( - "[Reshape::output_shapes] Unable to infer output shape."); - } - - auto& in = inputs[0]; - Shape output_shape(expression_.size()); +Shape Reshape::shape_from_expressions( + const std::vector>& expressions, + const std::unordered_map& char_to_dim, + const array& in) { + Shape output_shape(expressions.size()); int dim_to_infer = -1; - for (int i = 0, j = 0; i < expression_.size(); ++i) { - auto& e = expression_[i]; + uint64_t size = 1; + for (int i = 0; i < expressions.size(); ++i) { + auto& e = expressions[i]; if (auto pv = std::get_if(&e); pv) { if (*pv == -1) { dim_to_infer = i; @@ -2885,20 +2882,66 @@ std::vector Reshape::output_shapes(const std::vector& inputs) { } } else { auto& s = std::get(e); - output_shape[i] = in.shape()[j++]; + if (s.size() == 1) { + output_shape[i] = in.shape()[char_to_dim.at(s[0])]; + } else { + int d; + size_t loc = 0; + char op = 0; + while (loc < s.size()) { + int res; + if (std::isdigit(s[loc])) { + char* p; + res = std::strtol(s.c_str() + loc, &p, 10); + loc = (p - s.c_str()); + } else if (std::isalpha(s[loc])) { + res = in.shape()[char_to_dim.at(s[loc++])]; + } else if (s[loc] == '*' || s[loc] == '/') { + op = s[loc++]; + continue; + } + + if (op == '*') { + d *= res; + } else if (op == '/') { + d /= res; + } else { + d = res; + } + } + output_shape[i] = d; + } } + size *= output_shape[i]; } if (dim_to_infer >= 0) { - uint64_t output_size = 1; - for (int i = 0; i < output_shape.size(); ++i) { - if (i != dim_to_infer) { - output_size *= output_shape[i]; - } + if (size == 0) { + throw std::invalid_argument( + "[dynamic_reshape] Cannot infer the shape of an empty array."); } - output_shape[dim_to_infer] = in.size() / output_size; + auto d = in.size() / size; + output_shape[dim_to_infer] = d; + size *= d; } - return {std::move(output_shape)}; + + if (in.size() != size) { + std::ostringstream msg; + msg << "[dynamic_reshape] Cannot reshape array of size " << in.size() + << " into shape " << output_shape << "."; + throw std::invalid_argument(msg.str()); + } + + return output_shape; +} + +std::vector Reshape::output_shapes(const std::vector& inputs) { + // Only allowed to dynamically reshape when the shape is {} + if (expressions_.empty() && !shape_.empty()) { + throw std::invalid_argument( + "[Reshape::output_shapes] Unable to infer output shape."); + } + return {shape_from_expressions(expressions_, char_to_dim_, inputs[0])}; } std::vector Reduce::vjp( diff --git a/mlx/primitives.h b/mlx/primitives.h index e9b0a743a..c37d0cfaa 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1611,8 +1611,11 @@ class Reshape : public UnaryPrimitive { explicit Reshape( Stream stream, - std::vector> expression) - : UnaryPrimitive(stream), expression_(std::move(expression)) {} + std::vector> expressions, + std::unordered_map char_to_dim) + : UnaryPrimitive(stream), + expressions_(std::move(expressions)), + char_to_dim_(std::move(char_to_dim)) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -1623,9 +1626,15 @@ class Reshape : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; + static Shape shape_from_expressions( + const std::vector>& expressions, + const std::unordered_map& char_to_dim, + const array& in); + private: Shape shape_; - std::vector> expression_; + std::vector> expressions_; + std::unordered_map char_to_dim_; void eval(const std::vector& inputs, array& out); diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index c2efd0ce0..ef352da8a 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -4,6 +4,7 @@ #define _USE_MATH_DEFINES #include +#include // TODO #include #include "doctest/doctest.h" @@ -3777,6 +3778,10 @@ TEST_CASE("test dynamic reshape") { // Bad character CHECK_THROWS(dynamic_reshape(x, {"&", 1, 1})); + // Malformed + CHECK_THROWS(dynamic_reshape(x, {"+a", 1, 1})); + CHECK_THROWS(dynamic_reshape(x, {"a+", 1, 1})); + // No dim in string CHECK_THROWS(dynamic_reshape(x, {"1", 1, 1})); @@ -3785,7 +3790,41 @@ TEST_CASE("test dynamic reshape") { // Too many dims CHECK_THROWS(dynamic_reshape(x, {"a", "b", "c", "d"})); + CHECK_THROWS(dynamic_reshape(x, {"abcd", 1, 1})); // Too many inferred dims CHECK_THROWS(dynamic_reshape(x, {"a", -1, -1})); + + // Bad sizes + x = zeros({2, 2, 2}); + CHECK_THROWS_AS(dynamic_reshape(x, {7}), std::invalid_argument); + CHECK_THROWS_AS(dynamic_reshape(x, {-1, 7}), std::invalid_argument); + + // Works with empty array + x = array({}); + auto y = dynamic_reshape(x, {0, 0, 0}); + CHECK_EQ(y.shape(), Shape{0, 0, 0}); + CHECK_THROWS_AS(dynamic_reshape(x, {}), std::invalid_argument); + CHECK_THROWS_AS(dynamic_reshape(x, {1}), std::invalid_argument); + y = dynamic_reshape(x, {1, 5, 0}); + CHECK_EQ(y.shape(), Shape{1, 5, 0}); + + x = array({1, 2, 3}); + y = dynamic_reshape(x, {"a", 1, 1}); + CHECK_EQ(y.shape(), Shape{3, 1, 1}); + + x = zeros({2, 2}); + y = dynamic_reshape(x, {"a*b"}); + CHECK_EQ(y.shape(), Shape{4}); + + y = dynamic_reshape(x, {"2*a"}); + CHECK_EQ(y.shape(), Shape{4}); + + x = zeros({2, 20}); + y = dynamic_reshape(x, {"a*20"}); + CHECK_EQ(y.shape(), Shape{40}); + + x = zeros({2, 20}); + y = dynamic_reshape(x, {"a", "b/10", 10}); + CHECK_EQ(y.shape(), Shape{2, 2, 10}); }