diff --git a/mlx/backend/common/inverse.cpp b/mlx/backend/common/inverse.cpp
index 2dfc78d21..7c442342e 100644
--- a/mlx/backend/common/inverse.cpp
+++ b/mlx/backend/common/inverse.cpp
@@ -2,6 +2,7 @@

 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
+#include "mlx/linalg.h"
 #include "mlx/primitives.h"

 #ifdef ACCELERATE_NEW_LAPACK
@@ -92,4 +93,12 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
   inverse_impl(inputs[0], output);
 }

+std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto ax = axes[0] >= 0 ? 0 : -1;
+  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
+  return {{linalg::inv(a, stream())}, {ax}};
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/common/svd.cpp b/mlx/backend/common/svd.cpp
index 412f06297..0b56339aa 100644
--- a/mlx/backend/common/svd.cpp
+++ b/mlx/backend/common/svd.cpp
@@ -3,6 +3,7 @@
 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
 #include "mlx/backend/common/lapack_helper.h"
+#include "mlx/linalg.h"
 #include "mlx/primitives.h"

 namespace mlx::core {
@@ -144,4 +145,12 @@ void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
   svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
 }

+std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto ax = axes[0] >= 0 ? 0 : -1;
+  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
+  return {{linalg::svd(a, stream())}, {ax, ax, ax}};
+}
+
 } // namespace mlx::core
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index e911d19dd..f7cb29f89 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -1127,7 +1127,7 @@ std::pair<std::vector<array>, std::vector<int>> Equal::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
-  return {{equal(a, b, stream())}, axes};
+  return {{equal(a, b, stream())}, {to_ax}};
 }

 std::vector<array> Equal::vjp(
@@ -1468,7 +1468,7 @@ std::pair<std::vector<array>, std::vector<int>> Greater::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
-  return {{greater(a, b, stream())}, axes};
+  return {{greater(a, b, stream())}, {to_ax}};
 }

 std::vector<array> Greater::vjp(
@@ -1495,7 +1495,7 @@ std::pair<std::vector<array>, std::vector<int>> GreaterEqual::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
-  return {{greater_equal(a, b, stream())}, axes};
+  return {{greater_equal(a, b, stream())}, {to_ax}};
 }

 std::vector<array> GreaterEqual::vjp(
@@ -1522,7 +1522,7 @@ std::pair<std::vector<array>, std::vector<int>> Less::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
-  return {{less(a, b, stream())}, axes};
+  return {{less(a, b, stream())}, {to_ax}};
 }

 std::vector<array> Less::vjp(
@@ -1549,7 +1549,7 @@ std::pair<std::vector<array>, std::vector<int>> LessEqual::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
-  return {{less_equal(a, b, stream())}, axes};
+  return {{less_equal(a, b, stream())}, {to_ax}};
 }

 std::vector<array> LessEqual::vjp(
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 8a0a06894..78e759d62 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1929,6 +1929,7 @@ class SVD : public Primitive {
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;

+  DEFINE_VMAP()
   DEFINE_PRINT(SVD)

  private:
@@ -1943,6 +1944,7 @@ class Inverse : public UnaryPrimitive {
   void eval_cpu(const std::vector<array>& inputs, array& output) override;
   void eval_gpu(const std::vector<array>& inputs, array& output) override;

+  DEFINE_VMAP()
   DEFINE_PRINT(Inverse)

  private:
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index 7a7a134fd..03e05de1d 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -655,6 +655,7 @@ std::vector<array> vmap_replace(
     }

     auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);

+    // For each primitive's outputs add its id, the vout id and the vax
     auto outputs = a.outputs();
     for (int i = 0; i < v_outputs.size(); ++i) {
diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py
index a5d6bcc65..08512fb29 100644
--- a/python/tests/test_vmap.py
+++ b/python/tests/test_vmap.py
@@ -314,6 +314,64 @@ class TestVmap(mlx_tests.MLXTestCase):
         expected = mx.addmm(mx.moveaxis(c, 2, 0), a, mx.moveaxis(b, 1, 0))
         self.assertTrue(mx.allclose(out, expected))

+    def test_vmap_svd(self):
+        a = mx.random.uniform(shape=(3, 4, 2))
+
+        cpu_svd = lambda x: mx.linalg.svd(x, stream=mx.cpu)
+
+        # Vmap over the first axis (this is already supported natively by the primitive).
+        Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(0,))(a)
+        self.assertEqual(Us.shape, (a.shape[0], a.shape[1], a.shape[1]))
+        self.assertEqual(Ss.shape, (a.shape[0], a.shape[2]))
+        self.assertEqual(Vts.shape, (a.shape[0], a.shape[2], a.shape[2]))
+
+        for i in range(a.shape[0]):
+            M = a[i]
+            U, S, Vt = Us[i], Ss[i], Vts[i]
+            self.assertTrue(
+                mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
+            )
+
+        # Vmap over the second axis.
+        Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(1,))(a)
+        self.assertEqual(Us.shape, (a.shape[1], a.shape[0], a.shape[0]))
+        self.assertEqual(Ss.shape, (a.shape[1], a.shape[2]))
+        self.assertEqual(Vts.shape, (a.shape[1], a.shape[2], a.shape[2]))
+
+        for i in range(a.shape[1]):
+            M = a[:, i, :]
+            U, S, Vt = Us[i], Ss[i], Vts[i]
+            self.assertTrue(
+                mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
+            )
+
+    def test_vmap_inverse(self):
+        a = mx.random.uniform(shape=(3, 4, 4))
+
+        cpu_inv = lambda x: mx.linalg.inv(x, stream=mx.cpu)
+
+        # Vmap over the first axis (this is already supported natively by the primitive).
+        invs = mx.vmap(cpu_inv, in_axes=(0,))(a)
+
+        for i in range(a.shape[0]):
+            self.assertTrue(
+                mx.allclose(a[i] @ invs[i], mx.eye(a.shape[1]), rtol=0, atol=1e-5)
+            )
+
+        a = mx.random.uniform(shape=(4, 3, 4))
+
+        # Without vmapping, each input matrix is not square.
+        with self.assertRaises(ValueError):
+            mx.eval(cpu_inv(a))
+
+        # Vmap over the second axis.
+        invs = mx.vmap(cpu_inv, in_axes=(1,))(a)
+
+        for i in range(a.shape[1]):
+            self.assertTrue(
+                mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
+            )
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/vmap_tests.cpp b/tests/vmap_tests.cpp
index 5ff419b1e..f954cad6e 100644
--- a/tests/vmap_tests.cpp
+++ b/tests/vmap_tests.cpp
@@ -413,3 +413,35 @@ TEST_CASE("test vmap gather") {
     CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
   }
 }
+
+TEST_CASE("test vmap SVD") {
+  auto fun = [](std::vector<array> inputs) {
+    return linalg::svd(inputs.at(0), Device::cpu);
+  };
+
+  auto a = astype(reshape(arange(24), {3, 4, 2}), float32);
+
+  // vmap over the second axis.
+  {
+    auto out = vmap(fun, /* in_axes = */ {1})({a});
+    const auto& U = out.at(0);
+    const auto& S = out.at(1);
+    const auto& Vt = out.at(2);
+
+    CHECK_EQ(U.shape(), std::vector<int>{a.shape(1), a.shape(0), a.shape(0)});
+    CHECK_EQ(S.shape(), std::vector<int>{a.shape(1), a.shape(2)});
+    CHECK_EQ(Vt.shape(), std::vector<int>{a.shape(1), a.shape(2), a.shape(2)});
+  }
+
+  // vmap over the third axis.
+  {
+    auto out = vmap(fun, /* in_axes = */ {2})({a});
+    const auto& U = out.at(0);
+    const auto& S = out.at(1);
+    const auto& Vt = out.at(2);
+
+    CHECK_EQ(U.shape(), std::vector<int>{a.shape(2), a.shape(0), a.shape(0)});
+    CHECK_EQ(S.shape(), std::vector<int>{a.shape(2), a.shape(0)});
+    CHECK_EQ(Vt.shape(), std::vector<int>{a.shape(2), a.shape(1), a.shape(1)});
+  }
+}
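
For reference, a minimal usage sketch of what this change enables from Python. It mirrors the new tests above (`mx.vmap`, `mx.linalg.inv`, `mx.linalg.svd` on the CPU stream); the shapes are illustrative and the commented output shapes follow from the tests.

```python
import mlx.core as mx

# Batched inverse, vmapping over the second axis: each slice a[:, i, :] is 4x4.
a = mx.random.uniform(shape=(4, 3, 4))
invs = mx.vmap(lambda x: mx.linalg.inv(x, stream=mx.cpu), in_axes=(1,))(a)
print(invs.shape)  # (3, 4, 4)

# Batched SVD, vmapping over the last axis: each slice b[:, :, i] is 3x4.
b = mx.random.uniform(shape=(3, 4, 2))
U, S, Vt = mx.vmap(lambda x: mx.linalg.svd(x, stream=mx.cpu), in_axes=(2,))(b)
print(U.shape, S.shape, Vt.shape)  # (2, 3, 3) (2, 3) (2, 4, 4)
```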