From 048fabdabd2fdb0135446291e84f88b61f898ee9 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Wed, 30 Oct 2024 16:16:53 -0700 Subject: [PATCH] Fix vmap constant output size (#1524) * use inputs to determine output size * remove noop vmap tests --- mlx/transforms.cpp | 17 ++++++++++++++++- python/tests/test_vmap.py | 20 ++++++++++++++++++++ tests/vmap_tests.cpp | 37 ++----------------------------------- 3 files changed, 38 insertions(+), 36 deletions(-) diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 6c6f81868..1d5127389 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -686,6 +686,17 @@ std::vector vmap_replace( throw std::invalid_argument(msg.str()); } + int vmap_size = -1; + for (int i = 0; i < inputs.size(); ++i) { + if (in_axes[i] >= 0) { + vmap_size = inputs[i].shape(in_axes[i]); + break; + } + } + if (vmap_size == -1) { + throw std::invalid_argument("At least one of in_axes must be non-None."); + } + std::unordered_map> tmap; std::unordered_set needs_vmap; std::unordered_set cache; @@ -782,7 +793,11 @@ std::vector vmap_replace( } outputs.push_back(out); } else { - outputs.push_back(s_outputs[i]); + // When the output has no input dependencies + // use the size of the vmapped axis in the inputs to expand the output + array output = expand_dims(s_outputs[i], out_axes[i]); + output = repeat(output, vmap_size, out_axes[i]); + outputs.push_back(output); } } return outputs; diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index cd1d882fb..866012a12 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -462,6 +462,26 @@ class TestVmap(mlx_tests.MLXTestCase): expected[:, 0] = mx.array([1, 2, 3])[:, None] self.assertTrue(mx.allclose(out, expected)) + def test_vmap_const_func(self): + a = mx.random.uniform(shape=(2, 3, 4)) + b = mx.random.uniform(shape=(4, 3)) + + def const_func(a, b): + return mx.array(2) + + out = mx.vmap(const_func, in_axes=(0, None))(a, b) + self.assertTrue(mx.array_equal(mx.full((2,), 2), out)) + out = mx.vmap(const_func, in_axes=(None, 0))(a, b) + self.assertTrue(mx.array_equal(mx.full((4,), 2), out)) + out = mx.vmap(const_func, in_axes=(1, 1))(a, b) + self.assertTrue(mx.array_equal(mx.full((3,), 2), out)) + + with self.assertRaises(ValueError): + out = mx.vmap(const_func, in_axes=(None, None))(a, b) + + with self.assertRaises(ValueError): + out = mx.vmap(const_func, in_axes=(0, 0))(a, b) + if __name__ == "__main__": unittest.main() diff --git a/tests/vmap_tests.cpp b/tests/vmap_tests.cpp index 7e87469b2..1403d87ca 100644 --- a/tests/vmap_tests.cpp +++ b/tests/vmap_tests.cpp @@ -34,12 +34,8 @@ TEST_CASE("test simple vmap") { CHECK_THROWS_AS(vmap(fun, 0, -1), std::invalid_argument); CHECK_THROWS_AS(vmap(fun, -1, 0), std::invalid_argument); - auto vfun = vmap(fun, -1, -1); - auto x = zeros({2}); - CHECK(array_equal(vfun(x), zeros({4, 2})).item()); - - vfun = vmap(fun); - x = zeros({3, 2}); + auto vfun = vmap(fun); + auto x = zeros({3, 2}); CHECK(array_equal(vfun(x), zeros({3, 4, 2})).item()); vfun = vmap(fun, 0, 1); @@ -121,16 +117,9 @@ TEST_CASE("test simple vmap") { out = vfun({x, y})[0]; CHECK(array_equal(out, full({3, 2}, 2.0)).item()); - CHECK_THROWS_AS(vmap(fun, {-1, -1}, {0}), std::invalid_argument); CHECK_THROWS_AS(vmap(fun, {-1, 0}, {-1}), std::invalid_argument); CHECK_THROWS_AS(vmap(fun, {0, -1}, {-1}), std::invalid_argument); - x = array(1.); - y = array(1.); - vfun = vmap(fun, {-1, -1}, {-1}); - out = vfun({x, y})[0]; - CHECK(array_equal(out, array(2.)).item()); - x = ones({3, 2, 1}); y = ones({3, 2, 1}); vfun = vmap(vmap(fun)); @@ -187,13 +176,6 @@ TEST_CASE("test simple vmap") { CHECK_THROWS_AS(vmap(fun, {-1, -1, 0}, {-1}), std::invalid_argument); CHECK_THROWS_AS(vmap(fun, {0, -1, -1}, {-1}), std::invalid_argument); - cond = array({true, false}); - x = array(1.); - y = array(2.); - vfun = vmap(fun, {-1, -1, -1}, {-1}); - out = vfun({cond, x, y})[0]; - CHECK(array_equal(out, array({1.0, 2.0})).item()); - cond = array({1, 1, 1, 0, 0, 0}, {3, 2, 1}); x = ones({3, 2, 1}); y = full({3, 2, 1}, 2); @@ -424,21 +406,6 @@ TEST_CASE("test vmap scatter") { }; }; - { - // vmap nothing. - auto a = zeros({3, 4}); - auto indices = array({1}); - auto updates = reshape(array({1, 2}, float32), {1, 1, 2}); - - auto func = make_scatter_fn({indices}, updates, std::vector{0}); - auto out = vmap(func, /* in_axes = */ {-1}, /* out_axes = */ {-1})({a})[0]; - auto expected = - array({0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0}, {3, 4}, float32); - // Non-vmapped function output. - CHECK(array_equal(func({a}).at(0), expected).item()); - CHECK(array_equal(out, expected).item()); - } - { // vmap src on axis 0, scatter on axis 0. auto a = zeros({2, 3, 4});