From 19ec0232566470117b5a474f21a06bf1e117ad4d Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 14 Mar 2024 14:38:22 -0700
Subject: [PATCH] vmap matmul and admm (#836)

---
Review notes (text in this area is not part of the commit message):

* Primitive::jvp / Primitive::vjp / Primitive::vmap: the default
  implementations now build their error text with an std::ostringstream:
  "[Primitive::<name>] Not implemented for ", followed by whatever print()
  writes for the primitive, followed by ".". They still throw
  std::invalid_argument.
* Fixed while reviewing: the vjp message was tagged "[Primitive::vip]"; it
  now reads "[Primitive::vjp]" like its jvp/vmap siblings. Hunk line counts
  are unchanged, but the post-image blob id on the mlx/primitives.cpp
  "index" line still describes the original commit.
* AddMM::vmap / Matmul::vmap: every input whose vmap axis is > 0 is moved
  to axis 0 with moveaxis(); inputs whose axis is 0 or negative are passed
  through untouched. The op is then re-issued on stream() and the output is
  reported as mapped over axis 0. AddMM forwards alpha_ and beta_.
* NOTE(review): a non-positive axis other than 0 presumably marks an input
  that is not mapped (in_axes=None in the tests), relying on matmul/addmm
  broadcasting. Please confirm the behaviour when an unmapped input has
  leading batch dimensions of its own, or when a mapped input is 1-D.
* NOTE(review): the copy of this patch under review had lost its line
  breaks and its template argument lists (e.g. "std::vector" with no
  element type). They were restored below from how the values are used;
  confirm the context lines against mlx/primitives.cpp and
  mlx/primitives.h before applying.

 mlx/primitives.cpp        | 41 ++++++++++++++++++++++++++++++++++++++---
 mlx/primitives.h          |  2 ++
 python/tests/test_vmap.py | 39 +++++++++++++++++++++++++++++++++++++++
 3 files changed, 79 insertions(+), 3 deletions(-)

diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index d729ea9ed..f2b7f16af 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -110,7 +110,11 @@ std::vector<array> Primitive::jvp(
     const std::vector<array>&,
     const std::vector<array>&,
     const std::vector<int>&) {
-  throw std::invalid_argument("Primitive's jvp not implemented.");
+  std::ostringstream msg;
+  msg << "[Primitive::jvp] Not implemented for ";
+  print(msg);
+  msg << ".";
+  throw std::invalid_argument(msg.str());
 };
 
 std::vector<array> Primitive::vjp(
@@ -118,13 +122,21 @@ std::vector<array> Primitive::vjp(
     const std::vector<array>&,
     const std::vector<int>&,
     const std::vector<array>&) {
-  throw std::invalid_argument("Primitive's vjp not implemented.");
+  std::ostringstream msg;
+  msg << "[Primitive::vjp] Not implemented for ";
+  print(msg);
+  msg << ".";
+  throw std::invalid_argument(msg.str());
 };
 
 std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
     const std::vector<array>&,
     const std::vector<int>&) {
-  throw std::invalid_argument("Primitive's vmap not implemented.");
+  std::ostringstream msg;
+  msg << "[Primitive::vmap] Not implemented for ";
+  print(msg);
+  msg << ".";
+  throw std::invalid_argument(msg.str());
 };
 
 std::vector<std::vector<int>> Primitive::output_shapes(
@@ -235,6 +247,18 @@ bool AddMM::is_equivalent(const Primitive& other) const {
   return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_);
 }
 
+std::pair<std::vector<array>, std::vector<int>> AddMM::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto maybe_move_ax = [this](auto& arr, auto ax) {
+    return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr;
+  };
+  auto a = maybe_move_ax(inputs[0], axes[0]);
+  auto b = maybe_move_ax(inputs[1], axes[1]);
+  auto c = maybe_move_ax(inputs[2], axes[2]);
+  return {{addmm(c, a, b, alpha_, beta_, stream())}, {0}};
+}
+
 bool Arange::is_equivalent(const Primitive& other) const {
   const Arange& a_other = static_cast<const Arange&>(other);
   return (
@@ -1772,6 +1796,17 @@ std::vector<array> Matmul::vjp(
   return vjps;
 }
 
+std::pair<std::vector<array>, std::vector<int>> Matmul::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto maybe_move_ax = [this](auto& arr, auto ax) {
+    return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr;
+  };
+  auto a = maybe_move_ax(inputs[0], axes[0]);
+  auto b = maybe_move_ax(inputs[1], axes[1]);
+  return {{matmul(a, b, stream())}, {0}};
+}
+
 std::vector<array> Maximum::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 3b79231de..394d8ecb7 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -200,6 +200,7 @@ class AddMM : public UnaryPrimitive {
       const std::vector<int>& argnums,
       const std::vector<array>& outputs) override;
 
+  DEFINE_VMAP()
   DEFINE_PRINT(AddMM)
   bool is_equivalent(const Primitive& other) const override;
 
@@ -1140,6 +1141,7 @@ class Matmul : public UnaryPrimitive {
       const std::vector<int>& argnums,
       const std::vector<array>& outputs) override;
 
+  DEFINE_VMAP()
   DEFINE_PRINT(Matmul)
   DEFINE_DEFAULT_IS_EQUIVALENT()
 };
diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py
index 99e34c21b..a5d6bcc65 100644
--- a/python/tests/test_vmap.py
+++ b/python/tests/test_vmap.py
@@ -275,6 +275,45 @@ class TestVmap(mlx_tests.MLXTestCase):
         with self.assertRaises(ValueError):
             out = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))(a, b)
 
+    def test_vmap_matmul(self):
+        a = mx.random.uniform(shape=(2, 3, 4))
+        b = mx.random.uniform(shape=(4, 3))
+
+        # matmul
+        out = mx.vmap(mx.matmul, in_axes=(0, None))(a, b)
+        self.assertTrue(mx.allclose(out, a @ b))
+
+        # addmm
+        c = mx.random.uniform(shape=(3,))
+        out = mx.vmap(mx.addmm, in_axes=(None, 0, None))(c, a, b)
+        self.assertTrue(mx.allclose(out, mx.addmm(c, a, b)))
+
+        b = mx.random.uniform(shape=(4, 2))
+
+        # matmul
+        out = mx.vmap(mx.matmul, in_axes=(1, None), out_axes=(1,))(a, b)
+        expected = mx.moveaxis(mx.moveaxis(a, 1, 0) @ b, 0, 1)
+        self.assertTrue(mx.allclose(out, expected))
+
+        # addmm
+        c = mx.random.uniform(shape=(2,))
+        out = mx.vmap(mx.addmm, in_axes=(None, 1, None))(c, a, b)
+        self.assertTrue(mx.allclose(out, mx.addmm(c, mx.moveaxis(a, 1, 0), b)))
+
+        a = mx.random.uniform(shape=(2, 3, 4))
+        b = mx.random.uniform(shape=(4, 2, 3))
+
+        # matmul
+        out = mx.vmap(mx.matmul, in_axes=(0, 1))(a, b)
+        expected = a @ mx.moveaxis(b, 1, 0)
+        self.assertTrue(mx.allclose(out, expected))
+
+        # addmm
+        c = mx.random.uniform(shape=(3, 3, 2))
+        out = mx.vmap(mx.addmm, in_axes=(2, 0, 1))(c, a, b)
+        expected = mx.addmm(mx.moveaxis(c, 2, 0), a, mx.moveaxis(b, 1, 0))
+        self.assertTrue(mx.allclose(out, expected))
+
 
 if __name__ == "__main__":
     unittest.main()