try dynamic reshape

This commit is contained in:
Awni Hannun 2024-12-06 12:09:08 -08:00
parent 40c62c1321
commit ee59d50293
7 changed files with 164 additions and 0 deletions

View File

@ -403,6 +403,69 @@ array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
return array(std::move(shape), a.dtype(), std::move(p), {a});
}
/**
 * Reshape `a` using a list of per-dimension expressions.
 *
 * Each entry of `expressions` is either an `int` or a `std::string`. An
 * empty list is forwarded to the static `reshape(a, {}, s)`.
 *
 * Validation performed here:
 *  - at most one integer entry may be -1 (the dimension to infer),
 *  - every string entry must contain at least one alphabetic character,
 *  - string entries may only contain letters, digits, '*' and '/',
 *  - the number of distinct letters over all string entries may not exceed
 *    `a.ndim()`.
 *
 * @param a The input array.
 * @param expressions The int / string expressions describing the output.
 * @param s Stream or device passed to `to_stream` / `reshape`.
 * @return An array backed by a `Reshape` primitive holding `expressions`.
 * @throws std::invalid_argument if any of the checks above fails.
 */
array dynamic_reshape(
    const array& a,
    std::vector<std::variant<int, std::string>> expressions,
    StreamOrDevice s /* = {} */) {
  // Reshape to scalar is not dynamic
  if (expressions.empty()) {
    return reshape(a, {}, s);
  }

  bool infer_dim = false;
  // Distinct letters seen across all string expressions.
  std::unordered_set<char> dims;
  for (const auto& e : expressions) {
    if (auto pv = std::get_if<int>(&e); pv) {
      if (*pv == -1) {
        if (infer_dim) {
          throw std::invalid_argument(
              "[dynamic_reshape] Cannot infer more than one dimension.");
        }
        infer_dim = true;
      }
    } else {
      // Named `expr` rather than `s` so the stream parameter is not shadowed.
      const auto& expr = std::get<std::string>(e);
      bool has_alpha = false;
      for (auto c : expr) {
        // The <cctype> classifiers are undefined for negative values, so
        // go through unsigned char before calling them.
        const auto uc = static_cast<unsigned char>(c);
        if (isalpha(uc)) {
          has_alpha = true;
          dims.insert(c);
        } else if (!isdigit(uc) && c != '*' && c != '/') {
          std::ostringstream msg;
          msg << "[dynamic_reshape] Invalid character in string expression \""
              << expr << "\".";
          throw std::invalid_argument(msg.str());
        }
      }
      if (!has_alpha) {
        std::ostringstream msg;
        msg << "[dynamic_reshape] String expression must contain at least "
            << "one alphabetic character but got: \"" << expr << "\".";
        throw std::invalid_argument(msg.str());
      }
    }
  }

  // Up to a.ndim() distinct letters are allowed; only strictly more is an
  // error (the previous `>=` also rejected exactly a.ndim() letters).
  if (dims.size() > a.ndim()) {
    std::ostringstream msg;
    msg << "[dynamic_reshape] Expressions contain " << dims.size()
        << " abstract dimensions for array with only " << a.ndim()
        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

  // NOTE(review): the output shape is left empty here; the commented-out
  // call suggests it is meant to be derived from the expressions - confirm.
  auto output_shape = Shape{}; // Reshape::shape_from_expression(a, expression);
  return array(
      std::move(output_shape),
      a.dtype(),
      std::make_shared<Reshape>(to_stream(s), std::move(expressions)),
      {a});
}
array flatten(
const array& a,
int start_axis,

View File

@ -117,6 +117,12 @@ array triu(array x, int k = 0, StreamOrDevice s = {});
/** Reshape an array to the given shape. */
array reshape(const array& a, Shape shape, StreamOrDevice s = {});
/**
 * Dynamically reshape an array based on the given expressions.
 *
 * Each expression is either an integer or a string. At most one integer
 * may be -1 (a dimension to infer). String expressions must contain at
 * least one letter and may otherwise only use digits, `*` and `/`; the
 * number of distinct letters is validated against `a.ndim()`. An empty
 * list of expressions reshapes to a scalar. Invalid expressions raise
 * `std::invalid_argument`.
 */
array dynamic_reshape(
const array& a,
std::vector<std::variant<int, std::string>> expressions,
StreamOrDevice s = {});
/** Flatten the dimensions in the range `[start_axis, end_axis]` . */
array flatten(
const array& a,

View File

@ -2858,9 +2858,49 @@ std::vector<array> Reshape::jvp(
// Two Reshape primitives are equivalent only when both the static target
// shape and the dynamic expression list match. Comparing both members keeps
// the relation symmetric: a static reshape to `{}` (empty shape_, empty
// expression_) must not compare equal to a dynamic reshape (empty shape_,
// non-empty expression_).
bool Reshape::is_equivalent(const Primitive& other) const {
  // `other` is cast without a type check, as in the original code.
  const Reshape& r_other = static_cast<const Reshape&>(other);
  return shape_ == r_other.shape_ && expression_ == r_other.expression_;
}
// Compute the output shape of a dynamic reshape from the stored expressions
// and the shape of `inputs[0]`.
//
// - A non-negative/explicit int entry is copied to the output as is.
// - A single -1 entry is inferred as in.size() divided by the product of
//   all other output dimensions.
// - A string entry currently takes the next input dimension in order; the
//   contents of the string are not evaluated here.
//
// Throws std::invalid_argument when the primitive holds a static shape only,
// when there are more string entries than input dimensions, or when a
// dimension has to be inferred but the remaining dimensions multiply to 0.
std::vector<Shape> Reshape::output_shapes(const std::vector<array>& inputs) {
  // Only allowed to dynamically reshape when the shape is {}
  if (expression_.empty() && !shape_.empty()) {
    throw std::invalid_argument(
        "[Reshape::output_shapes] Unable to infer output shape.");
  }

  auto& in = inputs[0];
  const int num_out = static_cast<int>(expression_.size());
  const int num_in = static_cast<int>(in.ndim());

  Shape output_shape(expression_.size());
  int dim_to_infer = -1;
  // `i` walks the output dimensions, `j` the input dimensions consumed by
  // string entries.
  for (int i = 0, j = 0; i < num_out; ++i) {
    auto& e = expression_[i];
    if (auto pv = std::get_if<int>(&e); pv) {
      if (*pv == -1) {
        dim_to_infer = i;
        continue;
      }
      output_shape[i] = *pv;
    } else {
      // Guard against reading past the end of the input shape.
      if (j >= num_in) {
        throw std::invalid_argument(
            "[Reshape::output_shapes] More string expressions than input "
            "dimensions.");
      }
      output_shape[i] = in.shape()[j++];
    }
  }

  if (dim_to_infer >= 0) {
    uint64_t output_size = 1;
    for (int i = 0; i < num_out; ++i) {
      if (i != dim_to_infer) {
        output_size *= output_shape[i];
      }
    }
    // Avoid an integer division by zero below.
    if (output_size == 0) {
      throw std::invalid_argument(
          "[Reshape::output_shapes] Cannot infer a dimension when the "
          "remaining dimensions have zero size.");
    }
    output_shape[dim_to_infer] = in.size() / output_size;
  }
  return {std::move(output_shape)};
}
std::vector<array> Reduce::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@ -1609,6 +1609,11 @@ class Reshape : public UnaryPrimitive {
// Static reshape: stores a copy of the target `shape`; `expression_` is
// left default-constructed (empty).
explicit Reshape(Stream stream, const Shape& shape)
: UnaryPrimitive(stream), shape_(shape) {}
// Dynamic reshape: takes ownership of the int / string `expression` list;
// `shape_` is left default-constructed (empty).
explicit Reshape(
Stream stream,
std::vector<std::variant<int, std::string>> expression)
: UnaryPrimitive(stream), expression_(std::move(expression)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1616,9 +1621,11 @@ class Reshape : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Reshape)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
private:
// Target shape; initialized by the (Stream, Shape) constructor.
Shape shape_;
// Per-dimension int / string expressions; initialized by the
// (Stream, expression) constructor.
std::vector<std::variant<int, std::string>> expression_;
void eval(const std::vector<array>& inputs, array& out);

View File

@ -4880,4 +4880,27 @@ void init_ops(nb::module_& m) {
Returns:
array: The imaginary part of ``a``.
)pbdoc");
m.def(
"dynamic_reshape",
&dynamic_reshape,
nb::arg(),
"expression"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def dynamic_reshape(a: array, /, expression: Sequence[Union[int, str]], *, stream: "
"Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Dynamically reshape an array based on the given expression.
Args:
a (array): Input array.
expression (tuple(int or str)): The expression which determines the
output shape.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The reshaped array.
)pbdoc");
}

View File

@ -2713,6 +2713,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.imag(z).dtype, mx.float32)
self.assertTrue(mx.array_equal(mx.imag(z), y))
def test_dynamic_reshape(self):
    # An empty expression should produce a scalar-shaped result.
    source = mx.array(1)[None, None]
    reshaped = mx.dynamic_reshape(source, ())
    self.assertEqual(reshaped.shape, ())
if __name__ == "__main__":
unittest.main()

View File

@ -3769,3 +3769,23 @@ TEST_CASE("test contiguous") {
CHECK(x.flags().col_contiguous);
CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});
}
TEST_CASE("test dynamic reshape") {
  // Three-dimensional input used by every check below.
  auto input = array({1}, {1, 1, 1});

  // An empty expression list reshapes to a scalar.
  CHECK_EQ(dynamic_reshape(input, {}).shape(), Shape{});

  // '&' is not an accepted expression character.
  CHECK_THROWS(dynamic_reshape(input, {"&", 1, 1}));

  // A string expression needs at least one letter.
  CHECK_THROWS(dynamic_reshape(input, {"1", 1, 1}));

  // Four distinct letters in a single expression exceed the input rank.
  CHECK_THROWS(dynamic_reshape(input, {"abcd", 1, 1}));

  // Four distinct letters spread over four expressions exceed it as well.
  CHECK_THROWS(dynamic_reshape(input, {"a", "b", "c", "d"}));

  // Only one dimension may be inferred with -1.
  CHECK_THROWS(dynamic_reshape(input, {"a", -1, -1}));
}