diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index d50014ab8..26ac3e98f 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -1103,6 +1103,9 @@ array moveaxis(
   };
   source = check_ax(source);
   destination = check_ax(destination);
+  if (source == destination) {
+    return a;
+  }
   std::vector reorder(a.ndim());
   std::iota(reorder.begin(), reorder.end(), 0);
   reorder.erase(reorder.begin() + source);
@@ -2715,9 +2718,8 @@ array scatter(
   if (updates.ndim() != (a.ndim() + idx_shape.size())) {
     std::ostringstream msg;
     msg << "[scatter] Updates with " << updates.ndim()
-        << " dimensions does not match the sum of the array and indices "
-           "dimensions "
-        << a.ndim() + idx_shape.size() << ".";
+        << " dimensions does not match the sum of the array (" << a.ndim()
+        << ") and indices (" << idx_shape.size() << ") dimensions.";
     throw std::invalid_argument(msg.str());
   }
   for (int i = 0; i < idx_shape.size(); ++i) {
@@ -2759,11 +2761,12 @@ array scatter(
   inputs.insert(inputs.begin(), a);
   // TODO promote or cast?
   inputs.push_back(astype(updates, a.dtype(), s));
+
   return array(
       a.shape(),
       a.dtype(),
       std::make_shared(to_stream(s), mode, axes),
-      inputs);
+      std::move(inputs));
 }
 
 array scatter(
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 46cc3df33..ba818c99d 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -1,4 +1,5 @@
 // Copyright © 2023-2024 Apple Inc.
+
 #include
 #include
 #include
@@ -2976,6 +2977,90 @@ std::vector Scatter::jvp(
   throw std::runtime_error("[scatter] JVP not yet implemented");
 }
 
+// Vmap rule for Scatter. `inputs_` is {src, indices..., updates} and
+// `vmap_axes[i]` is the mapped axis of `inputs_[i]`, or negative when that
+// input is not mapped. The mapped axis is turned into one more scattered axis
+// of the source, indexed by an arange over the vmap size, so the result is a
+// single Scatter. Returns the output and the axis of it that is mapped.
+std::pair<std::vector<array>, std::vector<int>> Scatter::vmap(
+    const std::vector<array>& inputs_,
+    const std::vector<int>& vmap_axes) {
+  assert(inputs_.size() >= 2);
+  assert(inputs_.size() == vmap_axes.size());
+
+  auto inputs = inputs_;
+
+  auto scatter_axes = axes_;
+  int src_ax = vmap_axes[0];
+
+  // Locate the first mapped input; it determines the size of the vmap axis.
+  auto vmap_ax_it = std::find_if(
+      vmap_axes.begin(), vmap_axes.end(), [](int a) { return a >= 0; });
+  // Only dereference the iterator when a mapped input exists; reading through
+  // `end()` is undefined behavior.
+  if (vmap_ax_it != vmap_axes.end()) {
+    auto vmap_ax = *vmap_ax_it;
+    auto vmap_size = inputs[vmap_ax_it - vmap_axes.begin()].shape(vmap_ax);
+    if (src_ax < 0) {
+      // The source is not mapped: give it a leading vmap axis.
+      src_ax = 0;
+      inputs[0] =
+          repeat(expand_dims(inputs[0], 0, stream()), vmap_size, 0, stream());
+    }
+    const int num_inputs = static_cast<int>(vmap_axes.size());
+    for (int i = 1; i < num_inputs - 1; ++i) {
+      if (vmap_axes[i] >= 0) {
+        // vmap axis for indices goes to 0
+        inputs[i] = moveaxis(inputs[i], vmap_axes[i], 0, stream());
+      } else {
+        // insert a vmap axis and repeat
+        inputs[i] =
+            repeat(expand_dims(inputs[i], 0, stream()), vmap_size, 0, stream());
+      }
+      // Adjust non-vmapped index axes to account for the extra vmap dimension.
+      if (scatter_axes[i - 1] >= src_ax) {
+        scatter_axes[i - 1]++;
+      }
+    }
+
+    // Index the vmap axis of the source with 0..vmap_size-1, shaped to
+    // broadcast against `inputs[1]`.
+    // NOTE(review): `inputs[1]` is the updates array when there are no index
+    // inputs -- confirm that case is handled as intended.
+    auto vmap_inds = arange(vmap_size, inputs[1].dtype(), stream());
+    auto vmap_inds_shape = std::vector<int>(inputs[1].ndim(), 1);
+    vmap_inds_shape[0] = vmap_inds.size();
+    vmap_inds = reshape(vmap_inds, std::move(vmap_inds_shape), stream());
+    inputs.insert(
+        inputs.end() - 1, broadcast_to(vmap_inds, inputs[1].shape(), stream()));
+    scatter_axes.push_back(src_ax);
+
+    // Clone updates along the vmap dimension so they can be applied to each
+    // source tensor in the vmap.
+    auto& updates = inputs.back();
+    if (vmap_axes.back() < 0) {
+      updates = expand_dims(
+          updates, {0, static_cast<int>(inputs[1].ndim())}, stream());
+      updates = repeat(updates, vmap_size, 0, stream());
+    } else {
+      updates =
+          expand_dims(updates, static_cast<int>(inputs[1].ndim()), stream());
+      updates = moveaxis(updates, vmap_axes.back(), 0, stream());
+    }
+  }
+
+  // Take the shape by value: `inputs` is moved from in the same call below.
+  auto shape = inputs[0].shape();
+  auto dtype = inputs[0].dtype();
+  auto out = array(
+      std::move(shape),
+      dtype,
+      std::make_shared<Scatter>(stream(), reduce_type_, scatter_axes),
+      std::move(inputs));
+
+  return {{out}, {src_ax}};
+}
+
 std::vector Sigmoid::vjp(
     const std::vector& primals,
     const std::vector& cotangents,
diff --git a/mlx/primitives.h b/mlx/primitives.h
index c7fb29f1f..342afdc7b 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1713,7 +1713,9 @@ class Scatter : public UnaryPrimitive {
   void eval_cpu(const std::vector& inputs, array& out) override;
   void eval_gpu(const std::vector& inputs, array& out) override;
 
+  DEFINE_VMAP();
   DEFINE_GRADS();
+
   void print(std::ostream& os) override {
     os << "Scatter";
     switch (reduce_type_) {
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index c8345e2cc..f8d5648f8 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -563,8 +563,10 @@ std::pair, std::vector> vmap_trace(
   detail::InTracing in_tracing;
 
   if (in_axes.size() != inputs.size()) {
-    throw std::invalid_argument(
-        "[vmap] The number of in axes must match the number of inputs.");
+    std::stringstream ss;
+    ss << "[vmap] The number of in axes (" << in_axes.size()
+       << ") must match the number of inputs (" << inputs.size() << ").";
+    throw std::invalid_argument(ss.str());
   }
 
   // Some error checking and get the vmap axis size
@@ -620,8 +622,10 @@ std::vector vmap_replace(
     const std::vector& in_axes,
     const std::vector& out_axes) {
   if (out_axes.size() != s_outputs.size()) {
-    throw std::invalid_argument(
-        "[vmap] The number of out axes must match the number of outputs.");
+    std::stringstream msg;
+    msg << "[vmap] The number of out axes (" << out_axes.size()
+        << ") must match the number of outputs (" << s_outputs.size() << ").";
+    throw std::invalid_argument(msg.str());
   }
 
   std::unordered_map> tmap;
diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py
index adba57534..cd1d882fb 100644
--- a/python/tests/test_vmap.py
+++ b/python/tests/test_vmap.py
@@ -370,6 +370,99 @@ class TestVmap(mlx_tests.MLXTestCase):
                 mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
             )
 
+    def test_vmap_scatter(self):
+        """vmap of scatter / scatter-add over source, indices, and updates."""
+        def scatter(a):
+            a[mx.array(0)] = mx.array(0.0)
+            return a
+
+        a = mx.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])
+        out = mx.vmap(scatter)(a)
+        expected = mx.array([[0.0, 2.0, 3.0], [0.0, 3.0, 4.0]])
+        self.assertTrue(mx.allclose(out, expected))
+
+        out = mx.vmap(scatter, in_axes=(1,), out_axes=1)(a)
+        expected = mx.array([[0.0, 0.0, 0.0], [2.0, 3.0, 4.0]])
+        self.assertTrue(mx.allclose(out, expected))
+
+        def scatter_add(a):
+            return a.at[mx.array(0)].add(mx.array(1.0))
+
+        a = mx.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])
+        out = mx.vmap(scatter_add)(a)
+        expected = mx.array([[2.0, 2.0, 3.0], [3.0, 3.0, 4.0]])
+        self.assertTrue(mx.allclose(out, expected))
+
+        out = mx.vmap(scatter_add, in_axes=(1,), out_axes=1)(a)
+        expected = mx.array([[2.0, 3.0, 4.0], [2.0, 3.0, 4.0]])
+        self.assertTrue(mx.allclose(out, expected))
+
+        # Multiple indices
+        def scatter(a):
+            a[mx.array([0, 1]), mx.array([0, 1])] = mx.array((1.0, 1.0))
+            return a
+
+        a = mx.zeros((3, 3, 3))
+
+        expected = mx.repeat(scatter(mx.zeros((3, 3)))[None], 3, axis=0)
+        out = mx.vmap(scatter, in_axes=(0,), out_axes=0)(a)
+        self.assertTrue(mx.allclose(out, expected))
+
+        expected = mx.zeros((3, 3, 3))
+        expected[0, :, 0] = 1
+        expected[1, :, 1] = 1
+        out = mx.vmap(scatter, in_axes=(1,), out_axes=1)(a)
+        self.assertTrue(mx.allclose(out, expected))
+
+        expected = mx.zeros((3, 3, 3))
+        expected[0, 0, :] = 1
+        expected[1, 1, :] = 1
+        out = mx.vmap(scatter, in_axes=(2,), out_axes=2)(a)
+        self.assertTrue(mx.allclose(out, expected))
+
+        # vmap over src and indices
+        def scatter(a, idx):
+            a[idx] = mx.array(1.0)
+            return a
+
+        a = mx.zeros((3, 4))
+        idx = mx.array([0, 1, 2])
+        out = mx.vmap(scatter, in_axes=(0, 0), out_axes=0)(a, idx)
+        self.assertTrue(mx.allclose(out, mx.eye(n=3, m=4)))
+
+        # vmap over only indices
+        out = mx.vmap(scatter, in_axes=(None, 0), out_axes=0)(a, idx)
+        expected = mx.zeros((3, 3, 4))
+        expected[0, 0] = 1
+        expected[1, 1] = 1
+        expected[2, 2] = 1
+        self.assertTrue(mx.allclose(out, expected))
+
+        # vmap over src, indices, updates
+        def scatter(a, idx, updates):
+            a[idx] = updates
+            return a
+
+        a = mx.zeros((3, 4))
+        idx = mx.array([0, 1, 2])
+        updates = mx.array([1, 2, 3])
+        out = mx.vmap(scatter, in_axes=(0, 0, 0), out_axes=0)(a, idx, updates)
+        expected = mx.diag(mx.array([1, 2, 3]), k=-1)[1:]
+        self.assertTrue(mx.allclose(out, expected))
+
+        # vmap over only updates
+        def scatter(a, idx, updates):
+            a[idx] = updates
+            return a
+
+        a = mx.zeros((3, 4))
+        idx = mx.array([0])
+        updates = mx.array([1, 2, 3])
+        out = mx.vmap(scatter, in_axes=(None, None, 0), out_axes=0)(a, idx, updates)
+        expected = mx.zeros((3, 3, 4))
+        expected[:, 0] = mx.array([1, 2, 3])[:, None]
+        self.assertTrue(mx.allclose(out, expected))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/vmap_tests.cpp b/tests/vmap_tests.cpp
index f954cad6e..7e87469b2 100644
--- a/tests/vmap_tests.cpp
+++ b/tests/vmap_tests.cpp
@@ -414,6 +414,95 @@ TEST_CASE("test vmap gather") {
   }
 }
 
+// Exercises Scatter::vmap with the source mapped on different axes.
+TEST_CASE("test vmap scatter") {
+  auto make_scatter_fn = [](const std::vector<array>& indices,
+                            const array& updates,
+                            const std::vector<int>& axes) {
+    return [=](const std::vector<array>& inputs) {
+      auto a = inputs.at(0);
+      return std::vector<array>{scatter(a, indices, updates, axes)};
+    };
+  };
+
+  {
+    // vmap nothing.
+    auto a = zeros({3, 4});
+    auto indices = array({1});
+    auto updates = reshape(array({1, 2}, float32), {1, 1, 2});
+
+    auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});
+    auto out = vmap(func, /* in_axes = */ {-1}, /* out_axes = */ {-1})({a})[0];
+    auto expected =
+        array({0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0}, {3, 4}, float32);
+    // Non-vmapped function output.
+    CHECK(array_equal(func({a}).at(0), expected).item<bool>());
+    CHECK(array_equal(out, expected).item<bool>());
+  }
+
+  {
+    // vmap src on axis 0, scatter on axis 0.
+    auto a = zeros({2, 3, 4});
+    auto indices = array({1});
+    auto updates = reshape(array({1, 2}, float32), {1, 1, 2});
+
+    auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});
+    auto out = vmap(func, /* in_axes = */ {0})({a})[0];
+    auto expected = array(
+        {0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0},
+        {2, 3, 4},
+        float32);
+    CHECK(array_equal(out, expected).item<bool>());
+  }
+
+  {
+    // vmap src on axis 1, scatter on axis 0.
+    auto a = zeros({3, 2, 4});
+    auto indices = array({1});
+    auto updates = reshape(array({1, 2}, float32), {1, 1, 2});
+
+    auto func = make_scatter_fn({indices}, updates, std::vector<int>{0});
+    auto out = vmap(func, /* in_axes = */ {1}, /* out_axes = */ {1})({a})[0];
+    auto expected = array(
+        {0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0,
+         1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+        {3, 2, 4},
+        float32);
+    CHECK(array_equal(out, expected).item<bool>());
+  }
+
+  {
+    // vmap src on axis 0, scatter on axis 1.
+    auto a = zeros({2, 3, 4});
+    auto indices = array({1});
+    auto updates = reshape(array({1, 2}, float32), {1, 2, 1});
+
+    auto func = make_scatter_fn({indices}, updates, std::vector<int>{1});
+    auto out = vmap(func, /* in_axes = */ {0})({a})[0];
+    auto expected = array(
+        {0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0,
+         0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0},
+        {2, 3, 4},
+        float32);
+    CHECK(array_equal(out, expected).item<bool>());
+  }
+
+  {
+    // vmap src on axis 2, scatter on axes (0, 1).
+    auto a = zeros({2, 3, 2});
+    auto indices = {array({1}), array({2})};
+    auto axes = {0, 1};
+    auto updates = reshape(array({1}, float32), {1, 1, 1});
+
+    auto func = make_scatter_fn(indices, updates, axes);
+    auto out = vmap(func, /* in_axes = */ {2}, /* out_axes = */ {2})({a})[0];
+    auto expected =
+        array({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1}, {2, 3, 2}, float32);
+    CHECK(array_equal(out, expected).item<bool>());
+  }
+}
+
 TEST_CASE("test vmap SVD") {
   auto fun = [](std::vector inputs) {
     return linalg::svd(inputs.at(0), Device::cpu);