Add vmap to scatter (#1200)

* Add vmap to scatter

* updates

* vmap updates + a few more tests

* bug fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
nicolov 2024-08-06 05:12:27 +02:00 committed by GitHub
parent 58d0e199e1
commit 8c9f0278b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 269 additions and 8 deletions

View File

@ -1103,6 +1103,9 @@ array moveaxis(
};
source = check_ax(source);
destination = check_ax(destination);
if (source == destination) {
return a;
}
std::vector<int> reorder(a.ndim());
std::iota(reorder.begin(), reorder.end(), 0);
reorder.erase(reorder.begin() + source);
@ -2715,9 +2718,8 @@ array scatter(
if (updates.ndim() != (a.ndim() + idx_shape.size())) {
std::ostringstream msg;
msg << "[scatter] Updates with " << updates.ndim()
<< " dimensions does not match the sum of the array and indices "
"dimensions "
<< a.ndim() + idx_shape.size() << ".";
<< " dimensions does not match the sum of the array (" << a.ndim()
<< ") and indices (" << idx_shape.size() << ") dimensions.";
throw std::invalid_argument(msg.str());
}
for (int i = 0; i < idx_shape.size(); ++i) {
@ -2759,11 +2761,12 @@ array scatter(
inputs.insert(inputs.begin(), a);
// TODO promote or cast?
inputs.push_back(astype(updates, a.dtype(), s));
return array(
a.shape(),
a.dtype(),
std::make_shared<Scatter>(to_stream(s), mode, axes),
inputs);
std::move(inputs));
}
array scatter(

View File

@ -1,4 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
@ -2976,6 +2977,77 @@ std::vector<array> Scatter::jvp(
throw std::runtime_error("[scatter] JVP not yet implemented");
}
std::pair<std::vector<array>, std::vector<int>> Scatter::vmap(
const std::vector<array>& inputs_,
const std::vector<int>& vmap_axes) {
assert(inputs_.size() >= 2);
assert(inputs_.size() == vmap_axes.size());
auto inputs = inputs_;
auto scatter_axes = axes_;
int src_ax = vmap_axes[0];
auto vmap_ax_it = std::find_if(
vmap_axes.begin(), vmap_axes.end(), [](int a) { return a >= 0; });
auto vmap_ax = *vmap_ax_it;
if (vmap_ax >= 0) {
auto vmap_size = inputs[vmap_ax_it - vmap_axes.begin()].shape(vmap_ax);
if (src_ax < 0) {
src_ax = 0;
inputs[0] =
repeat(expand_dims(inputs[0], 0, stream()), vmap_size, 0, stream());
}
for (int i = 1; i < vmap_axes.size() - 1; ++i) {
// vmap axis for indices goes to 0
if (vmap_axes[i] >= 0) {
inputs[i] = moveaxis(inputs[i], vmap_axes[i], 0, stream());
}
// insert a vmap axis and repeat
if (vmap_axes[i] < 0) {
auto idx_shape = inputs[i].shape();
inputs[i] =
repeat(expand_dims(inputs[i], 0, stream()), vmap_size, 0, stream());
}
// Adjust non-vmapped index axes to account for the extra vmap dimension.
if (scatter_axes[i - 1] >= src_ax) {
scatter_axes[i - 1]++;
}
}
auto vmap_inds = arange(vmap_size, inputs[1].dtype(), stream());
auto vmap_inds_shape = std::vector<int>(inputs[1].ndim(), 1);
vmap_inds_shape[0] = vmap_inds.size();
vmap_inds = reshape(vmap_inds, std::move(vmap_inds_shape), stream());
inputs.insert(
inputs.end() - 1, broadcast_to(vmap_inds, inputs[1].shape(), stream()));
scatter_axes.push_back(src_ax);
// Clone updates along the vmap dimension so they can be applied to each
// source tensor in the vmap.
auto& updates = inputs.back();
if (vmap_axes.back() < 0) {
updates = expand_dims(
updates, {0, static_cast<int>(inputs[1].ndim())}, stream());
updates = repeat(updates, vmap_size, 0, stream());
} else {
updates =
expand_dims(updates, static_cast<int>(inputs[1].ndim()), stream());
updates = moveaxis(updates, vmap_axes.back(), 0, stream());
}
}
auto& shape = inputs[0].shape();
auto dtype = inputs[0].dtype();
auto out = array(
shape,
dtype,
std::make_shared<Scatter>(stream(), reduce_type_, scatter_axes),
std::move(inputs));
return {{out}, {src_ax}};
}
std::vector<array> Sigmoid::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@ -1713,7 +1713,9 @@ class Scatter : public UnaryPrimitive {
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP();
DEFINE_GRADS();
void print(std::ostream& os) override {
os << "Scatter";
switch (reduce_type_) {

View File

@ -563,8 +563,10 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
detail::InTracing in_tracing;
if (in_axes.size() != inputs.size()) {
throw std::invalid_argument(
"[vmap] The number of in axes must match the number of inputs.");
std::stringstream ss;
ss << "[vmap] The number of in axes (" << in_axes.size()
<< ") must match the number of inputs (" << inputs.size() << ").";
throw std::invalid_argument(ss.str());
}
// Some error checking and get the vmap axis size
@ -620,8 +622,10 @@ std::vector<array> vmap_replace(
const std::vector<int>& in_axes,
const std::vector<int>& out_axes) {
if (out_axes.size() != s_outputs.size()) {
throw std::invalid_argument(
"[vmap] The number of out axes must match the number of outputs.");
std::stringstream msg;
msg << "[vmap] The number of out axes (" << out_axes.size()
<< ") must match the number of outputs (" << s_outputs.size() << ").";
throw std::invalid_argument(msg.str());
}
std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;

View File

@ -370,6 +370,98 @@ class TestVmap(mlx_tests.MLXTestCase):
mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
)
def test_vmap_scatter(self):
def scatter(a):
a[mx.array(0)] = mx.array(0.0)
return a
a = mx.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])
out = mx.vmap(scatter)(a)
expected = mx.array([[0.0, 2.0, 3.0], [0.0, 3.0, 4.0]])
self.assertTrue(mx.allclose(out, expected))
out = mx.vmap(scatter, in_axes=(1,), out_axes=1)(a)
expected = mx.array([[0.0, 0.0, 0.0], [2.0, 3.0, 4.0]])
self.assertTrue(mx.allclose(out, expected))
def scatter_add(a):
return a.at[mx.array(0)].add(mx.array(1.0))
a = mx.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])
out = mx.vmap(scatter_add)(a)
expected = mx.array([[2.0, 2.0, 3.0], [3.0, 3.0, 4.0]])
self.assertTrue(mx.allclose(out, expected))
out = mx.vmap(scatter_add, in_axes=(1,), out_axes=1)(a)
expected = mx.array([[2.0, 3.0, 4.0], [2.0, 3.0, 4.0]])
self.assertTrue(mx.allclose(out, expected))
# Multiple indices
def scatter(a):
a[mx.array([0, 1]), mx.array([0, 1])] = mx.array((1.0, 1.0))
return a
a = mx.zeros((3, 3, 3))
expected = mx.repeat(scatter(mx.zeros((3, 3)))[None], 3, axis=0)
out = mx.vmap(scatter, in_axes=(0,), out_axes=0)(a)
self.assertTrue(mx.allclose(out, expected))
expected = mx.zeros((3, 3, 3))
expected[0, :, 0] = 1
expected[1, :, 1] = 1
out = mx.vmap(scatter, in_axes=(1,), out_axes=1)(a)
self.assertTrue(mx.allclose(out, expected))
expected = mx.zeros((3, 3, 3))
expected[0, 0, :] = 1
expected[1, 1, :] = 1
out = mx.vmap(scatter, in_axes=(2,), out_axes=2)(a)
self.assertTrue(mx.allclose(out, expected))
# vmap over src and indices
def scatter(a, idx):
a[idx] = mx.array(1.0)
return a
a = mx.zeros((3, 4))
idx = mx.array([0, 1, 2])
out = mx.vmap(scatter, in_axes=(0, 0), out_axes=0)(a, idx)
self.assertTrue(mx.allclose(out, mx.eye(n=3, m=4)))
# vmap over only indices
out = mx.vmap(scatter, in_axes=(None, 0), out_axes=0)(a, idx)
expected = mx.zeros((3, 3, 4))
expected[0, 0] = 1
expected[1, 1] = 1
expected[2, 2] = 1
self.assertTrue(mx.allclose(out, expected))
# vmap over src, indices, updates
def scatter(a, idx, updates):
a[idx] = updates
return a
a = mx.zeros((3, 4))
idx = mx.array([0, 1, 2])
updates = mx.array([1, 2, 3])
out = mx.vmap(scatter, in_axes=(0, 0, 0), out_axes=0)(a, idx, updates)
expected = mx.diag(mx.array([1, 2, 3]), k=-1)[1:]
self.assertTrue(mx.allclose(out, expected))
# vmap over only updates
def scatter(a, idx, updates):
a[idx] = updates
return a
a = mx.zeros((3, 4))
idx = mx.array([0])
updates = mx.array([1, 2, 3])
out = mx.vmap(scatter, in_axes=(None, None, 0), out_axes=0)(a, idx, updates)
expected = mx.zeros((3, 3, 4))
expected[:, 0] = mx.array([1, 2, 3])[:, None]
self.assertTrue(mx.allclose(out, expected))
if __name__ == "__main__":
unittest.main()

View File

@ -414,6 +414,94 @@ TEST_CASE("test vmap gather") {
}
}
TEST_CASE("test vmap scatter") {
auto make_scatter_fn = [](const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes) {
return [=](const std::vector<array>& inputs) {
auto a = inputs.at(0);
return std::vector<array>{scatter(a, indices, updates, axes)};
};
};
{
// vmap nothing.
auto a = zeros({3, 4});
auto indices = array({1});
auto updates = reshape(array({1, 2}, float32), {1, 1, 2});
auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});
auto out = vmap(func, /* in_axes = */ {-1}, /* out_axes = */ {-1})({a})[0];
auto expected =
array({0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0}, {3, 4}, float32);
// Non-vmapped function output.
CHECK(array_equal(func({a}).at(0), expected).item<bool>());
CHECK(array_equal(out, expected).item<bool>());
}
{
// vmap src on axis 0, scatter on axis 0.
auto a = zeros({2, 3, 4});
auto indices = array({1});
auto updates = reshape(array({1, 2}, float32), {1, 1, 2});
auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});
auto out = vmap(func, /* in_axes = */ {0})({a})[0];
auto expected = array(
{0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0},
{2, 3, 4},
float32);
CHECK(array_equal(out, expected).item<bool>());
}
{
// vmap src on axis 1, scatter on axis 0.
auto a = zeros({3, 2, 4});
auto indices = array({1});
auto updates = reshape(array({1, 2}, float32), {1, 1, 2});
auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});
auto out = vmap(func, /* in_axes = */ {1}, /* out_axes = */ {1})({a})[0];
auto expected = array(
{0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0,
1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{3, 2, 4},
float32);
CHECK(array_equal(out, expected).item<bool>());
}
{
// vmap src on axis 0, scatter on axis 1.
auto a = zeros({2, 3, 4});
auto indices = array({1});
auto updates = reshape(array({1, 2}, float32), {1, 2, 1});
auto func = make_scatter_fn({indices}, updates, std::vector<int>{1});
auto out = vmap(func, /* in_axes = */ {0})({a})[0];
auto expected = array(
{0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0},
{2, 3, 4},
float32);
CHECK(array_equal(out, expected).item<bool>());
}
{
// vmap src on axis 2, scatter on axes (0, 1).
auto a = zeros({2, 3, 2});
auto indices = {array({1}), array({2})};
auto axes = {0, 1};
auto updates = reshape(array({1}, float32), {1, 1, 1});
auto func = make_scatter_fn(indices, updates, axes);
auto out = vmap(func, /* in_axes = */ {2}, /* out_axes = */ {2})({a})[0];
auto expected =
array({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1}, {2, 3, 2}, float32);
CHECK(array_equal(out, expected).item<bool>());
}
}
TEST_CASE("test vmap SVD") {
auto fun = [](std::vector<array> inputs) {
return linalg::svd(inputs.at(0), Device::cpu);