shapeless compile in docs and partially shapeless reshape (#1742)

This commit is contained in:
Awni Hannun 2025-01-02 16:24:42 -08:00 committed by GitHub
parent a64a8dfe45
commit ae69cb15e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 137 additions and 36 deletions

View File

@ -421,3 +421,73 @@ the most opportunity to optimize the computation graph:
# Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled
fun = mx.compile(outer)
Shapeless Compilation
---------------------
When the shape of an input to a compiled function changes, the function is
recompiled. You can compile a function once and run it on inputs with
variable shapes by specifying ``shapeless=True`` to :func:`compile`. In this
case changes to the shapes of the inputs do not cause the function to be
recompiled.
.. code-block:: python
def fun(x, y):
return mx.abs(x + y)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.array(1.0)
y = mx.array(-2.0)
# Firt call compiles the function
print(compiled_fun(x, y))
# Second call with different shapes
# does not recompile the function
x = mx.array([1.0, -6.0])
y = mx.array([-2.0, 3.0])
print(compiled_fun(x, y))
Use shapeless compilations carefully. Since compilation is not triggered when
shapes change, any graphs which are conditional on the input shapes will not
work as expected. Shape-dependent computations are common and sometimes subtle
to detect. For example:
.. code-block:: python
def fun(x):
return x.reshape(x.shape[0] * x.shape[1], -1)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.random.uniform(shape=(2, 3, 4))
out = compiled_fun(x)
x = mx.random.uniform(shape=(5, 5, 3))
# Error, can't reshape (5, 5, 3) to (6, -1)
out = compiled_fun(x)
The second call to the ``compiled_fun`` fails because of the call to
:func:`reshape` which uses the static shape of ``x`` in the first call. We can
fix this by using :func:`flatten` to avoid hardcoding the shape of ``x``:
.. code-block:: python
def fun(x):
return x.flatten(0, 1)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.random.uniform(shape=(2, 3, 4))
out = compiled_fun(x)
x = mx.random.uniform(shape=(5, 5, 3))
# Ok
out = compiled_fun(x)

View File

@ -363,41 +363,12 @@ array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
if (a.shape() == shape) {
return a;
}
size_t size = 1;
int infer_idx = -1;
for (int i = 0; i < shape.size(); ++i) {
if (shape[i] == -1) {
if (infer_idx >= 0) {
throw std::invalid_argument(
"[reshape] Reshape can only infer one dimension.");
}
infer_idx = i;
} else {
size *= shape[i];
}
}
// Infer the shape
if (size > 0) {
if (infer_idx >= 0) {
shape[infer_idx] = a.size() / size;
size *= shape[infer_idx];
}
} else if (infer_idx >= 0) {
throw std::invalid_argument(
"[reshape] Cannot infer the shape of an empty array");
}
// Check that the reshaping is valid
if (a.size() != size) {
std::ostringstream msg;
msg << "[reshape] Cannot reshape array of size " << a.size()
<< " into shape " << shape << ".";
throw std::invalid_argument(msg.str());
}
auto p = std::make_shared<Reshape>(to_stream(s), shape);
return array(std::move(shape), a.dtype(), std::move(p), {a});
auto out_shape = Reshape::output_shape(a, shape);
return array(
std::move(out_shape),
a.dtype(),
std::make_shared<Reshape>(to_stream(s), std::move(shape)),
{a});
}
array unflatten(

View File

@ -3021,6 +3021,44 @@ bool Reshape::is_equivalent(const Primitive& other) const {
return shape_ == r_other.shape_;
}
Shape Reshape::output_shape(const array& input, Shape shape) {
size_t size = 1;
int infer_idx = -1;
for (int i = 0; i < shape.size(); ++i) {
if (shape[i] == -1) {
if (infer_idx >= 0) {
throw std::invalid_argument(
"[reshape] Reshape can only infer one dimension.");
}
infer_idx = i;
} else {
size *= shape[i];
}
}
// Infer the shape
if (size > 0 && infer_idx >= 0) {
shape[infer_idx] = input.size() / size;
size *= shape[infer_idx];
} else if (infer_idx >= 0) {
throw std::invalid_argument(
"[reshape] Cannot infer the shape of an empty array");
}
// Check that the reshaping is valid
if (input.size() != size) {
std::ostringstream msg;
msg << "[reshape] Cannot reshape array of size " << input.size()
<< " into shape " << shape << ".";
throw std::invalid_argument(msg.str());
}
return shape;
}
std::vector<Shape> Reshape::output_shapes(const std::vector<array>& inputs) {
return {output_shape(inputs[0], shape_)};
}
std::vector<array> Reduce::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@ -1746,6 +1746,8 @@ class Reshape : public UnaryPrimitive {
std::vector<int> state() const {
return shape_;
};
static Shape output_shape(const array& input, Shape shape);
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
private:
Shape shape_;

View File

@ -830,6 +830,25 @@ class TestCompile(mlx_tests.MLXTestCase):
a = mx.array([0.0, 1.0, 2.0, 3.0, 4.0])
self.assertTrue(mx.allclose(cfun(a), fun(a)))
def test_shapeless_compile_with_reshape(self):
def fun(x):
return x.reshape(x.shape[0] * x.shape[1], -1)
compiled_fun = mx.compile(fun, shapeless=True)
x = mx.zeros(shape=(2, 3, 4))
out = compiled_fun(x)
self.assertEqual(out.shape, (6, 4))
x = mx.zeros(shape=(2, 3, 8))
out = compiled_fun(x)
self.assertEqual(out.shape, (6, 8))
x = mx.zeros(shape=(5, 5, 5))
with self.assertRaises(ValueError):
compiled_fun(x)
if __name__ == "__main__":
unittest.main()

View File

@ -685,7 +685,8 @@ auto compile_shapeless_ok(const std::vector<array>& inputs) {
TEST_CASE("test shapeless compile") {
{
auto cfun = compile(compile_shapeless_not_ok, /* shapeless */ true);
CHECK_THROWS(cfun({array({1, 2, 3, 4})}));
cfun({array({1, 2, 3, 4})});
CHECK_THROWS(cfun({array({1, 2, 3, 4, 5})}));
}
{