Fix vmap constant output size (#1524)

* use inputs to determine output size

* remove noop vmap tests
This commit is contained in:
Alex Barron 2024-10-30 16:16:53 -07:00 committed by GitHub
parent 917252a5a1
commit 048fabdabd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 38 additions and 36 deletions

View File

@ -686,6 +686,17 @@ std::vector<array> vmap_replace(
throw std::invalid_argument(msg.str());
}
int vmap_size = -1;
for (int i = 0; i < inputs.size(); ++i) {
if (in_axes[i] >= 0) {
vmap_size = inputs[i].shape(in_axes[i]);
break;
}
}
if (vmap_size == -1) {
throw std::invalid_argument("At least one of in_axes must be non-None.");
}
std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
std::unordered_set<std::uintptr_t> needs_vmap;
std::unordered_set<std::uintptr_t> cache;
@ -782,7 +793,11 @@ std::vector<array> vmap_replace(
}
outputs.push_back(out);
} else {
outputs.push_back(s_outputs[i]);
// When the output has no input dependencies
// use the size of the vmapped axis in the inputs to expand the output
array output = expand_dims(s_outputs[i], out_axes[i]);
output = repeat(output, vmap_size, out_axes[i]);
outputs.push_back(output);
}
}
return outputs;

View File

@ -462,6 +462,26 @@ class TestVmap(mlx_tests.MLXTestCase):
expected[:, 0] = mx.array([1, 2, 3])[:, None]
self.assertTrue(mx.allclose(out, expected))
def test_vmap_const_func(self):
a = mx.random.uniform(shape=(2, 3, 4))
b = mx.random.uniform(shape=(4, 3))
def const_func(a, b):
return mx.array(2)
out = mx.vmap(const_func, in_axes=(0, None))(a, b)
self.assertTrue(mx.array_equal(mx.full((2,), 2), out))
out = mx.vmap(const_func, in_axes=(None, 0))(a, b)
self.assertTrue(mx.array_equal(mx.full((4,), 2), out))
out = mx.vmap(const_func, in_axes=(1, 1))(a, b)
self.assertTrue(mx.array_equal(mx.full((3,), 2), out))
with self.assertRaises(ValueError):
out = mx.vmap(const_func, in_axes=(None, None))(a, b)
with self.assertRaises(ValueError):
out = mx.vmap(const_func, in_axes=(0, 0))(a, b)
if __name__ == "__main__":
unittest.main()

View File

@ -34,12 +34,8 @@ TEST_CASE("test simple vmap") {
CHECK_THROWS_AS(vmap(fun, 0, -1), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, -1, 0), std::invalid_argument);
auto vfun = vmap(fun, -1, -1);
auto x = zeros({2});
CHECK(array_equal(vfun(x), zeros({4, 2})).item<bool>());
vfun = vmap(fun);
x = zeros({3, 2});
auto vfun = vmap(fun);
auto x = zeros({3, 2});
CHECK(array_equal(vfun(x), zeros({3, 4, 2})).item<bool>());
vfun = vmap(fun, 0, 1);
@ -121,16 +117,9 @@ TEST_CASE("test simple vmap") {
out = vfun({x, y})[0];
CHECK(array_equal(out, full({3, 2}, 2.0)).item<bool>());
CHECK_THROWS_AS(vmap(fun, {-1, -1}, {0}), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, {-1, 0}, {-1}), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, {0, -1}, {-1}), std::invalid_argument);
x = array(1.);
y = array(1.);
vfun = vmap(fun, {-1, -1}, {-1});
out = vfun({x, y})[0];
CHECK(array_equal(out, array(2.)).item<bool>());
x = ones({3, 2, 1});
y = ones({3, 2, 1});
vfun = vmap(vmap(fun));
@ -187,13 +176,6 @@ TEST_CASE("test simple vmap") {
CHECK_THROWS_AS(vmap(fun, {-1, -1, 0}, {-1}), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, {0, -1, -1}, {-1}), std::invalid_argument);
cond = array({true, false});
x = array(1.);
y = array(2.);
vfun = vmap(fun, {-1, -1, -1}, {-1});
out = vfun({cond, x, y})[0];
CHECK(array_equal(out, array({1.0, 2.0})).item<bool>());
cond = array({1, 1, 1, 0, 0, 0}, {3, 2, 1});
x = ones({3, 2, 1});
y = full({3, 2, 1}, 2);
@ -424,21 +406,6 @@ TEST_CASE("test vmap scatter") {
};
};
{
// vmap nothing.
auto a = zeros({3, 4});
auto indices = array({1});
auto updates = reshape(array({1, 2}, float32), {1, 1, 2});
auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});
auto out = vmap(func, /* in_axes = */ {-1}, /* out_axes = */ {-1})({a})[0];
auto expected =
array({0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0}, {3, 4}, float32);
// Non-vmapped function output.
CHECK(array_equal(func({a}).at(0), expected).item<bool>());
CHECK(array_equal(out, expected).item<bool>());
}
{
// vmap src on axis 0, scatter on axis 0.
auto a = zeros({2, 3, 4});