diff --git a/docs/src/usage/compile.rst b/docs/src/usage/compile.rst index 091505fe4..d0c13a9a8 100644 --- a/docs/src/usage/compile.rst +++ b/docs/src/usage/compile.rst @@ -421,3 +421,73 @@ the most opportunity to optimize the computation graph: # Compiling the outer function is good to do as it will likely # be faster even though the inner functions are compiled fun = mx.compile(outer) + +Shapeless Compilation +--------------------- + +When the shape of an input to a compiled function changes, the function is +recompiled. You can compile a function once and run it on inputs with +variable shapes by specifying ``shapeless=True`` to :func:`compile`. In this +case changes to the shapes of the inputs do not cause the function to be +recompiled. + +.. code-block:: python + + def fun(x, y): + return mx.abs(x + y) + + compiled_fun = mx.compile(fun, shapeless=True) + + x = mx.array(1.0) + y = mx.array(-2.0) + + # First call compiles the function + print(compiled_fun(x, y)) + + # Second call with different shapes + # does not recompile the function + x = mx.array([1.0, -6.0]) + y = mx.array([-2.0, 3.0]) + print(compiled_fun(x, y)) + + +Use shapeless compilations carefully. Since compilation is not triggered when +shapes change, any graphs which are conditional on the input shapes will not +work as expected. Shape-dependent computations are common and sometimes subtle +to detect. For example: + +.. code-block:: python + + def fun(x): + return x.reshape(x.shape[0] * x.shape[1], -1) + + compiled_fun = mx.compile(fun, shapeless=True) + + x = mx.random.uniform(shape=(2, 3, 4)) + + out = compiled_fun(x) + + x = mx.random.uniform(shape=(5, 5, 3)) + + # Error, can't reshape (5, 5, 3) to (6, -1) + out = compiled_fun(x) + +The second call to the ``compiled_fun`` fails because of the call to +:func:`reshape` which uses the static shape of ``x`` in the first call. We can +fix this by using :func:`flatten` to avoid hardcoding the shape of ``x``: + +..
code-block:: python + + def fun(x): + return x.flatten(0, 1) + + compiled_fun = mx.compile(fun, shapeless=True) + + x = mx.random.uniform(shape=(2, 3, 4)) + + out = compiled_fun(x) + + x = mx.random.uniform(shape=(5, 5, 3)) + + # Ok + out = compiled_fun(x) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 99b6a721e..ac7238904 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -363,41 +363,12 @@ array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) { if (a.shape() == shape) { return a; } - - size_t size = 1; - int infer_idx = -1; - for (int i = 0; i < shape.size(); ++i) { - if (shape[i] == -1) { - if (infer_idx >= 0) { - throw std::invalid_argument( - "[reshape] Reshape can only infer one dimension."); - } - infer_idx = i; - } else { - size *= shape[i]; - } - } - - // Infer the shape - if (size > 0) { - if (infer_idx >= 0) { - shape[infer_idx] = a.size() / size; - size *= shape[infer_idx]; - } - } else if (infer_idx >= 0) { - throw std::invalid_argument( - "[reshape] Cannot infer the shape of an empty array"); - } - - // Check that the reshaping is valid - if (a.size() != size) { - std::ostringstream msg; - msg << "[reshape] Cannot reshape array of size " << a.size() - << " into shape " << shape << "."; - throw std::invalid_argument(msg.str()); - } - auto p = std::make_shared(to_stream(s), shape); - return array(std::move(shape), a.dtype(), std::move(p), {a}); + auto out_shape = Reshape::output_shape(a, shape); + return array( + std::move(out_shape), + a.dtype(), + std::make_shared(to_stream(s), std::move(shape)), + {a}); } array unflatten( diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index bf693bcb8..f68c69b2b 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3021,6 +3021,44 @@ bool Reshape::is_equivalent(const Primitive& other) const { return shape_ == r_other.shape_; } +Shape Reshape::output_shape(const array& input, Shape shape) { + size_t size = 1; + int infer_idx = -1; + for (int i = 0; i < shape.size(); ++i) { + if (shape[i] == -1) 
{ + if (infer_idx >= 0) { + throw std::invalid_argument( + "[reshape] Reshape can only infer one dimension."); + } + infer_idx = i; + } else { + size *= shape[i]; + } + } + + // Infer the shape + if (size > 0 && infer_idx >= 0) { + shape[infer_idx] = input.size() / size; + size *= shape[infer_idx]; + } else if (infer_idx >= 0) { + throw std::invalid_argument( + "[reshape] Cannot infer the shape of an empty array"); + } + + // Check that the reshaping is valid + if (input.size() != size) { + std::ostringstream msg; + msg << "[reshape] Cannot reshape array of size " << input.size() + << " into shape " << shape << "."; + throw std::invalid_argument(msg.str()); + } + return shape; +} + +std::vector Reshape::output_shapes(const std::vector& inputs) { + return {output_shape(inputs[0], shape_)}; +} + std::vector Reduce::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index a56ed4abd..74e50d4fb 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1746,6 +1746,8 @@ class Reshape : public UnaryPrimitive { std::vector state() const { return shape_; }; + static Shape output_shape(const array& input, Shape shape); + std::vector output_shapes(const std::vector& inputs) override; private: Shape shape_; diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 9dfd234b4..85a6036d7 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -830,6 +830,25 @@ class TestCompile(mlx_tests.MLXTestCase): a = mx.array([0.0, 1.0, 2.0, 3.0, 4.0]) self.assertTrue(mx.allclose(cfun(a), fun(a))) + def test_shapeless_compile_with_reshape(self): + def fun(x): + return x.reshape(x.shape[0] * x.shape[1], -1) + + compiled_fun = mx.compile(fun, shapeless=True) + + x = mx.zeros(shape=(2, 3, 4)) + out = compiled_fun(x) + self.assertEqual(out.shape, (6, 4)) + + x = mx.zeros(shape=(2, 3, 8)) + out = compiled_fun(x) + self.assertEqual(out.shape, (6, 8)) + + x = mx.zeros(shape=(5, 5, 5)) + + with 
self.assertRaises(ValueError): + compiled_fun(x) + if __name__ == "__main__": unittest.main() diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 9ca4cf19f..5026c8505 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -685,7 +685,8 @@ auto compile_shapeless_ok(const std::vector& inputs) { TEST_CASE("test shapeless compile") { { auto cfun = compile(compile_shapeless_not_ok, /* shapeless */ true); - CHECK_THROWS(cfun({array({1, 2, 3, 4})})); + cfun({array({1, 2, 3, 4})}); + CHECK_THROWS(cfun({array({1, 2, 3, 4, 5})})); } {