Compare commits

...

3 Commits

Author SHA1 Message Date
Awni Hannun
0c1155faf5 binding + tests 2024-12-09 12:57:36 -08:00
Awni Hannun
2b9c24c517 works 2024-12-09 12:57:36 -08:00
Awni Hannun
ee59d50293 try dynamic reshape 2024-12-09 12:57:36 -08:00
9 changed files with 299 additions and 1 deletions

View File

@@ -80,7 +80,8 @@ bool allows_shapeless(const Primitive& p) {
typeid(p) == typeid(Partition) || typeid(p) == typeid(Select) ||
typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) ||
typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) ||
typeid(p) == typeid(Reshape) || typeid(p) == typeid(Matmul) ||
typeid(p) == typeid(QuantizedMatmul) ||
typeid(p) == typeid(fast::AffineQuantize) ||
typeid(p) == typeid(fast::LayerNorm) ||
typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||

View File

@@ -403,6 +403,78 @@ array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
return array(std::move(shape), a.dtype(), std::move(p), {a});
}
// Variant of string and int for the expressions
array dynamic_reshape(
const array& a,
std::vector<std::variant<int, std::string>> expressions,
StreamOrDevice s /* = {} */) {
// Reshape to scalar is not dynamic
if (expressions.empty()) {
return reshape(a, {}, s);
}
// Validate expressions:
// - At most one item in expressions is -1
// - Any string expression should have a letter
// - At most a.ndim() unique letters
// - Only valid characters in string (alphabet, integer, *, /)
bool infer_dim = false;
std::unordered_map<char, int> char_to_dim;
for (auto& e : expressions) {
if (auto pv = std::get_if<int>(&e); pv) {
if (*pv == -1) {
if (infer_dim) {
throw std::invalid_argument(
"[dynamic_reshape] Cannot infer more than one dimension.");
}
infer_dim = true;
}
} else {
auto& s = std::get<std::string>(e);
bool has_alpha = false;
for (auto c : s) {
if (isalpha(c)) {
has_alpha = true;
char_to_dim.insert({c, char_to_dim.size()});
} else if (!isdigit(c) && c != '*' && c != '/') {
std::ostringstream msg;
msg << "[dynamic_reshape] Invalid character in string expression \""
<< s << "\".";
throw std::invalid_argument(msg.str());
}
}
if (!has_alpha) {
std::ostringstream msg;
msg << "[dynamic_reshape] String expression must contain at least "
<< "one alphabetic character but got: \"" << s << "\".";
throw std::invalid_argument(msg.str());
}
if (!isdigit(s[0]) && !isalpha(s[0]) && !isdigit(s.back()) &&
!isalpha(s.back())) {
std::ostringstream msg;
msg << "[dynamic_reshape] String expression must start and end with "
<< "integer or letter but got: \"" << s << "\".";
throw std::invalid_argument(msg.str());
}
}
}
if (char_to_dim.size() > a.ndim()) {
std::ostringstream msg;
msg << "[dynamic_reshape] Expressions contain " << char_to_dim.size()
<< " abstract dimensions for array with only " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
auto output_shape =
Reshape::shape_from_expressions(expressions, char_to_dim, a);
return array(
std::move(output_shape),
a.dtype(),
std::make_shared<Reshape>(
to_stream(s), std::move(expressions), std::move(char_to_dim)),
{a});
}
array flatten(
const array& a,
int start_axis,

View File

@@ -117,6 +117,12 @@ array triu(array x, int k = 0, StreamOrDevice s = {});
/** Reshape an array to the given shape. */
array reshape(const array& a, Shape shape, StreamOrDevice s = {});
/** Dynamically reshape an array based on the given expressions. */
array dynamic_reshape(
const array& a,
std::vector<std::variant<int, std::string>> expressions,
StreamOrDevice s = {});
/** Flatten the dimensions in the range `[start_axis, end_axis]` . */
array flatten(
const array& a,

View File

@@ -2858,9 +2858,92 @@ std::vector<array> Reshape::jvp(
bool Reshape::is_equivalent(const Primitive& other) const {
const Reshape& r_other = static_cast<const Reshape&>(other);
if (!expressions_.empty()) {
return expressions_ == r_other.expressions_;
}
return shape_ == r_other.shape_;
}
Shape Reshape::shape_from_expressions(
const std::vector<std::variant<int, std::string>>& expressions,
const std::unordered_map<char, int>& char_to_dim,
const array& in) {
Shape output_shape(expressions.size());
int dim_to_infer = -1;
uint64_t size = 1;
for (int i = 0; i < expressions.size(); ++i) {
auto& e = expressions[i];
if (auto pv = std::get_if<int>(&e); pv) {
if (*pv == -1) {
dim_to_infer = i;
continue;
} else {
output_shape[i] = *pv;
}
} else {
auto& s = std::get<std::string>(e);
if (s.size() == 1) {
output_shape[i] = in.shape()[char_to_dim.at(s[0])];
} else {
int d;
size_t loc = 0;
char op = 0;
while (loc < s.size()) {
int res;
if (std::isdigit(s[loc])) {
char* p;
res = std::strtol(s.c_str() + loc, &p, 10);
loc = (p - s.c_str());
} else if (std::isalpha(s[loc])) {
res = in.shape()[char_to_dim.at(s[loc++])];
} else if (s[loc] == '*' || s[loc] == '/') {
op = s[loc++];
continue;
}
if (op == '*') {
d *= res;
} else if (op == '/') {
d /= res;
} else {
d = res;
}
}
output_shape[i] = d;
}
}
size *= output_shape[i];
}
if (dim_to_infer >= 0) {
if (size == 0) {
throw std::invalid_argument(
"[dynamic_reshape] Cannot infer the shape of an empty array.");
}
auto d = in.size() / size;
output_shape[dim_to_infer] = d;
size *= d;
}
if (in.size() != size) {
std::ostringstream msg;
msg << "[dynamic_reshape] Cannot reshape array of size " << in.size()
<< " into shape " << output_shape << ".";
throw std::invalid_argument(msg.str());
}
return output_shape;
}
std::vector<Shape> Reshape::output_shapes(const std::vector<array>& inputs) {
// Only allowed to dynamically reshape when the shape is {}
if (expressions_.empty() && !shape_.empty()) {
throw std::invalid_argument(
"[Reshape::output_shapes] Unable to infer output shape.");
}
return {shape_from_expressions(expressions_, char_to_dim_, inputs[0])};
}
std::vector<array> Reduce::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@@ -1609,6 +1609,14 @@ class Reshape : public UnaryPrimitive {
explicit Reshape(Stream stream, const Shape& shape)
: UnaryPrimitive(stream), shape_(shape) {}
explicit Reshape(
Stream stream,
std::vector<std::variant<int, std::string>> expressions,
std::unordered_map<char, int> char_to_dim)
: UnaryPrimitive(stream),
expressions_(std::move(expressions)),
char_to_dim_(std::move(char_to_dim)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1616,9 +1624,17 @@ class Reshape : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_PRINT(Reshape)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
static Shape shape_from_expressions(
const std::vector<std::variant<int, std::string>>& expressions,
const std::unordered_map<char, int>& char_to_dim,
const array& in);
private:
Shape shape_;
std::vector<std::variant<int, std::string>> expressions_;
std::unordered_map<char, int> char_to_dim_;
void eval(const std::vector<array>& inputs, array& out);

View File

@@ -4880,4 +4880,27 @@ void init_ops(nb::module_& m) {
Returns:
array: The imaginary part of ``a``.
)pbdoc");
m.def(
"dynamic_reshape",
&dynamic_reshape,
nb::arg(),
"expressions"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def dynamic_reshape(a: array, /, expressions: Sequence[Union[int, str]], *, stream: "
"Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Dynamically reshape an array based on the given expression.
Args:
a (array): Input array.
expressions (tuple(int or str)): The expressions which determine
the output shape.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The reshaped array.
)pbdoc");
}

View File

@@ -809,6 +809,29 @@ class TestCompile(mlx_tests.MLXTestCase):
out = fun(*inputs)
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
def test_compile_shapeless_with_reshape(self):
def fun(a):
return mx.reshape(a, (4, 7, 4, 2))
cfun = mx.compile(fun, shapeless=True)
a = mx.zeros((4, 7, 8))
with self.assertRaises(ValueError):
b = cfun(a)
def fun(a):
return mx.dynamic_reshape(a, ("B", "L", 4, 2))
cfun = mx.compile(fun, shapeless=True)
b = cfun(a)
self.assertEqual(b.shape, (4, 7, 4, 2))
a = mx.zeros((4, 9, 8))
b = cfun(a)
self.assertEqual(b.shape, (4, 9, 4, 2))
if __name__ == "__main__":
unittest.main()

View File

@@ -2713,6 +2713,21 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.imag(z).dtype, mx.float32)
self.assertTrue(mx.array_equal(mx.imag(z), y))
def test_dynamic_reshape(self):
a = mx.array(1)[None, None]
a = mx.dynamic_reshape(a, ())
self.assertEqual(a.shape, ())
a = mx.zeros((4, 4, 4))
b = mx.dynamic_reshape(a, ("a", "b", "c"))
self.assertEqual(b.shape, (4, 4, 4))
b = mx.dynamic_reshape(a, ("a*b", "c"))
self.assertEqual(b.shape, (4 * 4, 4))
b = mx.dynamic_reshape(a, ("a*b*c", 1, 1))
self.assertEqual(b.shape, (4 * 4 * 4, 1, 1))
if __name__ == "__main__":
unittest.main()

View File

@@ -4,6 +4,7 @@
#define _USE_MATH_DEFINES
#include <cmath>
#include <iostream> // TODO
#include <numeric>
#include "doctest/doctest.h"
@@ -3769,3 +3770,61 @@ TEST_CASE("test contiguous") {
CHECK(x.flags().col_contiguous);
CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});
}
TEST_CASE("test dynamic reshape") {
auto x = array({1}, {1, 1, 1});
CHECK_EQ(dynamic_reshape(x, {}).shape(), Shape{});
// Bad character
CHECK_THROWS(dynamic_reshape(x, {"&", 1, 1}));
// Malformed
CHECK_THROWS(dynamic_reshape(x, {"+a", 1, 1}));
CHECK_THROWS(dynamic_reshape(x, {"a+", 1, 1}));
// No dim in string
CHECK_THROWS(dynamic_reshape(x, {"1", 1, 1}));
// Too many dims
CHECK_THROWS(dynamic_reshape(x, {"abcd", 1, 1}));
// Too many dims
CHECK_THROWS(dynamic_reshape(x, {"a", "b", "c", "d"}));
CHECK_THROWS(dynamic_reshape(x, {"abcd", 1, 1}));
// Too many inferred dims
CHECK_THROWS(dynamic_reshape(x, {"a", -1, -1}));
// Bad sizes
x = zeros({2, 2, 2});
CHECK_THROWS_AS(dynamic_reshape(x, {7}), std::invalid_argument);
CHECK_THROWS_AS(dynamic_reshape(x, {-1, 7}), std::invalid_argument);
// Works with empty array
x = array({});
auto y = dynamic_reshape(x, {0, 0, 0});
CHECK_EQ(y.shape(), Shape{0, 0, 0});
CHECK_THROWS_AS(dynamic_reshape(x, {}), std::invalid_argument);
CHECK_THROWS_AS(dynamic_reshape(x, {1}), std::invalid_argument);
y = dynamic_reshape(x, {1, 5, 0});
CHECK_EQ(y.shape(), Shape{1, 5, 0});
x = array({1, 2, 3});
y = dynamic_reshape(x, {"a", 1, 1});
CHECK_EQ(y.shape(), Shape{3, 1, 1});
x = zeros({2, 2});
y = dynamic_reshape(x, {"a*b"});
CHECK_EQ(y.shape(), Shape{4});
y = dynamic_reshape(x, {"2*a"});
CHECK_EQ(y.shape(), Shape{4});
x = zeros({2, 20});
y = dynamic_reshape(x, {"a*20"});
CHECK_EQ(y.shape(), Shape{40});
x = zeros({2, 20});
y = dynamic_reshape(x, {"a", "b/10", 10});
CHECK_EQ(y.shape(), Shape{2, 2, 10});
}