mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Fix vmap constant output size (#1524)
* use inputs to determine output size * remove noop vmap tests
This commit is contained in:
parent
917252a5a1
commit
048fabdabd
@ -686,6 +686,17 @@ std::vector<array> vmap_replace(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int vmap_size = -1;
|
||||||
|
for (int i = 0; i < inputs.size(); ++i) {
|
||||||
|
if (in_axes[i] >= 0) {
|
||||||
|
vmap_size = inputs[i].shape(in_axes[i]);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (vmap_size == -1) {
|
||||||
|
throw std::invalid_argument("At least one of in_axes must be non-None.");
|
||||||
|
}
|
||||||
|
|
||||||
std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
|
std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
|
||||||
std::unordered_set<std::uintptr_t> needs_vmap;
|
std::unordered_set<std::uintptr_t> needs_vmap;
|
||||||
std::unordered_set<std::uintptr_t> cache;
|
std::unordered_set<std::uintptr_t> cache;
|
||||||
@ -782,7 +793,11 @@ std::vector<array> vmap_replace(
|
|||||||
}
|
}
|
||||||
outputs.push_back(out);
|
outputs.push_back(out);
|
||||||
} else {
|
} else {
|
||||||
outputs.push_back(s_outputs[i]);
|
// When the output has no input dependencies
|
||||||
|
// use the size of the vmapped axis in the inputs to expand the output
|
||||||
|
array output = expand_dims(s_outputs[i], out_axes[i]);
|
||||||
|
output = repeat(output, vmap_size, out_axes[i]);
|
||||||
|
outputs.push_back(output);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return outputs;
|
return outputs;
|
||||||
|
@ -462,6 +462,26 @@ class TestVmap(mlx_tests.MLXTestCase):
|
|||||||
expected[:, 0] = mx.array([1, 2, 3])[:, None]
|
expected[:, 0] = mx.array([1, 2, 3])[:, None]
|
||||||
self.assertTrue(mx.allclose(out, expected))
|
self.assertTrue(mx.allclose(out, expected))
|
||||||
|
|
||||||
|
def test_vmap_const_func(self):
|
||||||
|
a = mx.random.uniform(shape=(2, 3, 4))
|
||||||
|
b = mx.random.uniform(shape=(4, 3))
|
||||||
|
|
||||||
|
def const_func(a, b):
|
||||||
|
return mx.array(2)
|
||||||
|
|
||||||
|
out = mx.vmap(const_func, in_axes=(0, None))(a, b)
|
||||||
|
self.assertTrue(mx.array_equal(mx.full((2,), 2), out))
|
||||||
|
out = mx.vmap(const_func, in_axes=(None, 0))(a, b)
|
||||||
|
self.assertTrue(mx.array_equal(mx.full((4,), 2), out))
|
||||||
|
out = mx.vmap(const_func, in_axes=(1, 1))(a, b)
|
||||||
|
self.assertTrue(mx.array_equal(mx.full((3,), 2), out))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
out = mx.vmap(const_func, in_axes=(None, None))(a, b)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
out = mx.vmap(const_func, in_axes=(0, 0))(a, b)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -34,12 +34,8 @@ TEST_CASE("test simple vmap") {
|
|||||||
CHECK_THROWS_AS(vmap(fun, 0, -1), std::invalid_argument);
|
CHECK_THROWS_AS(vmap(fun, 0, -1), std::invalid_argument);
|
||||||
CHECK_THROWS_AS(vmap(fun, -1, 0), std::invalid_argument);
|
CHECK_THROWS_AS(vmap(fun, -1, 0), std::invalid_argument);
|
||||||
|
|
||||||
auto vfun = vmap(fun, -1, -1);
|
auto vfun = vmap(fun);
|
||||||
auto x = zeros({2});
|
auto x = zeros({3, 2});
|
||||||
CHECK(array_equal(vfun(x), zeros({4, 2})).item<bool>());
|
|
||||||
|
|
||||||
vfun = vmap(fun);
|
|
||||||
x = zeros({3, 2});
|
|
||||||
CHECK(array_equal(vfun(x), zeros({3, 4, 2})).item<bool>());
|
CHECK(array_equal(vfun(x), zeros({3, 4, 2})).item<bool>());
|
||||||
|
|
||||||
vfun = vmap(fun, 0, 1);
|
vfun = vmap(fun, 0, 1);
|
||||||
@ -121,16 +117,9 @@ TEST_CASE("test simple vmap") {
|
|||||||
out = vfun({x, y})[0];
|
out = vfun({x, y})[0];
|
||||||
CHECK(array_equal(out, full({3, 2}, 2.0)).item<bool>());
|
CHECK(array_equal(out, full({3, 2}, 2.0)).item<bool>());
|
||||||
|
|
||||||
CHECK_THROWS_AS(vmap(fun, {-1, -1}, {0}), std::invalid_argument);
|
|
||||||
CHECK_THROWS_AS(vmap(fun, {-1, 0}, {-1}), std::invalid_argument);
|
CHECK_THROWS_AS(vmap(fun, {-1, 0}, {-1}), std::invalid_argument);
|
||||||
CHECK_THROWS_AS(vmap(fun, {0, -1}, {-1}), std::invalid_argument);
|
CHECK_THROWS_AS(vmap(fun, {0, -1}, {-1}), std::invalid_argument);
|
||||||
|
|
||||||
x = array(1.);
|
|
||||||
y = array(1.);
|
|
||||||
vfun = vmap(fun, {-1, -1}, {-1});
|
|
||||||
out = vfun({x, y})[0];
|
|
||||||
CHECK(array_equal(out, array(2.)).item<bool>());
|
|
||||||
|
|
||||||
x = ones({3, 2, 1});
|
x = ones({3, 2, 1});
|
||||||
y = ones({3, 2, 1});
|
y = ones({3, 2, 1});
|
||||||
vfun = vmap(vmap(fun));
|
vfun = vmap(vmap(fun));
|
||||||
@ -187,13 +176,6 @@ TEST_CASE("test simple vmap") {
|
|||||||
CHECK_THROWS_AS(vmap(fun, {-1, -1, 0}, {-1}), std::invalid_argument);
|
CHECK_THROWS_AS(vmap(fun, {-1, -1, 0}, {-1}), std::invalid_argument);
|
||||||
CHECK_THROWS_AS(vmap(fun, {0, -1, -1}, {-1}), std::invalid_argument);
|
CHECK_THROWS_AS(vmap(fun, {0, -1, -1}, {-1}), std::invalid_argument);
|
||||||
|
|
||||||
cond = array({true, false});
|
|
||||||
x = array(1.);
|
|
||||||
y = array(2.);
|
|
||||||
vfun = vmap(fun, {-1, -1, -1}, {-1});
|
|
||||||
out = vfun({cond, x, y})[0];
|
|
||||||
CHECK(array_equal(out, array({1.0, 2.0})).item<bool>());
|
|
||||||
|
|
||||||
cond = array({1, 1, 1, 0, 0, 0}, {3, 2, 1});
|
cond = array({1, 1, 1, 0, 0, 0}, {3, 2, 1});
|
||||||
x = ones({3, 2, 1});
|
x = ones({3, 2, 1});
|
||||||
y = full({3, 2, 1}, 2);
|
y = full({3, 2, 1}, 2);
|
||||||
@ -424,21 +406,6 @@ TEST_CASE("test vmap scatter") {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
{
|
|
||||||
// vmap nothing.
|
|
||||||
auto a = zeros({3, 4});
|
|
||||||
auto indices = array({1});
|
|
||||||
auto updates = reshape(array({1, 2}, float32), {1, 1, 2});
|
|
||||||
|
|
||||||
auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});
|
|
||||||
auto out = vmap(func, /* in_axes = */ {-1}, /* out_axes = */ {-1})({a})[0];
|
|
||||||
auto expected =
|
|
||||||
array({0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0}, {3, 4}, float32);
|
|
||||||
// Non-vmapped function output.
|
|
||||||
CHECK(array_equal(func({a}).at(0), expected).item<bool>());
|
|
||||||
CHECK(array_equal(out, expected).item<bool>());
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
{
|
||||||
// vmap src on axis 0, scatter on axis 0.
|
// vmap src on axis 0, scatter on axis 0.
|
||||||
auto a = zeros({2, 3, 4});
|
auto a = zeros({2, 3, 4});
|
||||||
|
Loading…
Reference in New Issue
Block a user