mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
Add vmap to scatter (#1200)
* Add vmap to scatter
* updates
* vmap updates + a few more tests
* bug fix
---------
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
58d0e199e1
commit
8c9f0278b9
11
mlx/ops.cpp
11
mlx/ops.cpp
@ -1103,6 +1103,9 @@ array moveaxis(
|
||||
};
|
||||
source = check_ax(source);
|
||||
destination = check_ax(destination);
|
||||
if (source == destination) {
|
||||
return a;
|
||||
}
|
||||
std::vector<int> reorder(a.ndim());
|
||||
std::iota(reorder.begin(), reorder.end(), 0);
|
||||
reorder.erase(reorder.begin() + source);
|
||||
@ -2715,9 +2718,8 @@ array scatter(
|
||||
if (updates.ndim() != (a.ndim() + idx_shape.size())) {
|
||||
std::ostringstream msg;
|
||||
msg << "[scatter] Updates with " << updates.ndim()
|
||||
<< " dimensions does not match the sum of the array and indices "
|
||||
"dimensions "
|
||||
<< a.ndim() + idx_shape.size() << ".";
|
||||
<< " dimensions does not match the sum of the array (" << a.ndim()
|
||||
<< ") and indices (" << idx_shape.size() << ") dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
for (int i = 0; i < idx_shape.size(); ++i) {
|
||||
@ -2759,11 +2761,12 @@ array scatter(
|
||||
inputs.insert(inputs.begin(), a);
|
||||
// TODO promote or cast?
|
||||
inputs.push_back(astype(updates, a.dtype(), s));
|
||||
|
||||
return array(
|
||||
a.shape(),
|
||||
a.dtype(),
|
||||
std::make_shared<Scatter>(to_stream(s), mode, axes),
|
||||
inputs);
|
||||
std::move(inputs));
|
||||
}
|
||||
|
||||
array scatter(
|
||||
|
@ -1,4 +1,5 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
@ -2976,6 +2977,77 @@ std::vector<array> Scatter::jvp(
|
||||
throw std::runtime_error("[scatter] JVP not yet implemented");
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Scatter::vmap(
|
||||
const std::vector<array>& inputs_,
|
||||
const std::vector<int>& vmap_axes) {
|
||||
assert(inputs_.size() >= 2);
|
||||
assert(inputs_.size() == vmap_axes.size());
|
||||
|
||||
auto inputs = inputs_;
|
||||
|
||||
auto scatter_axes = axes_;
|
||||
int src_ax = vmap_axes[0];
|
||||
|
||||
auto vmap_ax_it = std::find_if(
|
||||
vmap_axes.begin(), vmap_axes.end(), [](int a) { return a >= 0; });
|
||||
auto vmap_ax = *vmap_ax_it;
|
||||
if (vmap_ax >= 0) {
|
||||
auto vmap_size = inputs[vmap_ax_it - vmap_axes.begin()].shape(vmap_ax);
|
||||
if (src_ax < 0) {
|
||||
src_ax = 0;
|
||||
inputs[0] =
|
||||
repeat(expand_dims(inputs[0], 0, stream()), vmap_size, 0, stream());
|
||||
}
|
||||
for (int i = 1; i < vmap_axes.size() - 1; ++i) {
|
||||
// vmap axis for indices goes to 0
|
||||
if (vmap_axes[i] >= 0) {
|
||||
inputs[i] = moveaxis(inputs[i], vmap_axes[i], 0, stream());
|
||||
}
|
||||
// insert a vmap axis and repeat
|
||||
if (vmap_axes[i] < 0) {
|
||||
auto idx_shape = inputs[i].shape();
|
||||
inputs[i] =
|
||||
repeat(expand_dims(inputs[i], 0, stream()), vmap_size, 0, stream());
|
||||
}
|
||||
// Adjust non-vmapped index axes to account for the extra vmap dimension.
|
||||
if (scatter_axes[i - 1] >= src_ax) {
|
||||
scatter_axes[i - 1]++;
|
||||
}
|
||||
}
|
||||
|
||||
auto vmap_inds = arange(vmap_size, inputs[1].dtype(), stream());
|
||||
auto vmap_inds_shape = std::vector<int>(inputs[1].ndim(), 1);
|
||||
vmap_inds_shape[0] = vmap_inds.size();
|
||||
vmap_inds = reshape(vmap_inds, std::move(vmap_inds_shape), stream());
|
||||
inputs.insert(
|
||||
inputs.end() - 1, broadcast_to(vmap_inds, inputs[1].shape(), stream()));
|
||||
scatter_axes.push_back(src_ax);
|
||||
|
||||
// Clone updates along the vmap dimension so they can be applied to each
|
||||
// source tensor in the vmap.
|
||||
auto& updates = inputs.back();
|
||||
if (vmap_axes.back() < 0) {
|
||||
updates = expand_dims(
|
||||
updates, {0, static_cast<int>(inputs[1].ndim())}, stream());
|
||||
updates = repeat(updates, vmap_size, 0, stream());
|
||||
} else {
|
||||
updates =
|
||||
expand_dims(updates, static_cast<int>(inputs[1].ndim()), stream());
|
||||
updates = moveaxis(updates, vmap_axes.back(), 0, stream());
|
||||
}
|
||||
}
|
||||
|
||||
auto& shape = inputs[0].shape();
|
||||
auto dtype = inputs[0].dtype();
|
||||
auto out = array(
|
||||
shape,
|
||||
dtype,
|
||||
std::make_shared<Scatter>(stream(), reduce_type_, scatter_axes),
|
||||
std::move(inputs));
|
||||
|
||||
return {{out}, {src_ax}};
|
||||
}
|
||||
|
||||
std::vector<array> Sigmoid::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
@ -1713,7 +1713,9 @@ class Scatter : public UnaryPrimitive {
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP();
|
||||
DEFINE_GRADS();
|
||||
|
||||
void print(std::ostream& os) override {
|
||||
os << "Scatter";
|
||||
switch (reduce_type_) {
|
||||
|
@ -563,8 +563,10 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
|
||||
detail::InTracing in_tracing;
|
||||
|
||||
if (in_axes.size() != inputs.size()) {
|
||||
throw std::invalid_argument(
|
||||
"[vmap] The number of in axes must match the number of inputs.");
|
||||
std::stringstream ss;
|
||||
ss << "[vmap] The number of in axes (" << in_axes.size()
|
||||
<< ") must match the number of inputs (" << inputs.size() << ").";
|
||||
throw std::invalid_argument(ss.str());
|
||||
}
|
||||
|
||||
// Some error checking and get the vmap axis size
|
||||
@ -620,8 +622,10 @@ std::vector<array> vmap_replace(
|
||||
const std::vector<int>& in_axes,
|
||||
const std::vector<int>& out_axes) {
|
||||
if (out_axes.size() != s_outputs.size()) {
|
||||
throw std::invalid_argument(
|
||||
"[vmap] The number of out axes must match the number of outputs.");
|
||||
std::stringstream msg;
|
||||
msg << "[vmap] The number of out axes (" << out_axes.size()
|
||||
<< ") must match the number of outputs (" << s_outputs.size() << ").";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
|
||||
|
@ -370,6 +370,98 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
|
||||
)
|
||||
|
||||
def test_vmap_scatter(self):
    """Check mx.vmap over in-place scatter assignments and scatter-add,
    mapping over the source, the indices, the updates, and combinations."""

    def check(result, reference):
        self.assertTrue(mx.allclose(result, reference))

    # Single index, assignment.
    def zero_first(x):
        x[mx.array(0)] = mx.array(0.0)
        return x

    a = mx.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])
    check(mx.vmap(zero_first)(a), mx.array([[0.0, 2.0, 3.0], [0.0, 3.0, 4.0]]))
    check(
        mx.vmap(zero_first, in_axes=(1,), out_axes=1)(a),
        mx.array([[0.0, 0.0, 0.0], [2.0, 3.0, 4.0]]),
    )

    # Single index, scatter-add.
    def add_one_first(x):
        return x.at[mx.array(0)].add(mx.array(1.0))

    a = mx.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])
    check(
        mx.vmap(add_one_first)(a),
        mx.array([[2.0, 2.0, 3.0], [3.0, 3.0, 4.0]]),
    )
    check(
        mx.vmap(add_one_first, in_axes=(1,), out_axes=1)(a),
        mx.array([[2.0, 3.0, 4.0], [2.0, 3.0, 4.0]]),
    )

    # Multiple indices
    def set_diag_pair(x):
        x[mx.array([0, 1]), mx.array([0, 1])] = mx.array((1.0, 1.0))
        return x

    a = mx.zeros((3, 3, 3))

    want = mx.repeat(set_diag_pair(mx.zeros((3, 3)))[None], 3, axis=0)
    got = mx.vmap(set_diag_pair, in_axes=(0,), out_axes=0)(a)
    check(got, want)

    want = mx.zeros((3, 3, 3))
    want[0, :, 0] = 1
    want[1, :, 1] = 1
    got = mx.vmap(set_diag_pair, in_axes=(1,), out_axes=1)(a)
    check(got, want)

    want = mx.zeros((3, 3, 3))
    want[0, 0, :] = 1
    want[1, 1, :] = 1
    got = mx.vmap(set_diag_pair, in_axes=(2,), out_axes=2)(a)
    check(got, want)

    # vmap over src and indices
    def set_one_at(x, where):
        x[where] = mx.array(1.0)
        return x

    a = mx.zeros((3, 4))
    idx = mx.array([0, 1, 2])
    got = mx.vmap(set_one_at, in_axes=(0, 0), out_axes=0)(a, idx)
    check(got, mx.eye(n=3, m=4))

    # vmap over only indices
    got = mx.vmap(set_one_at, in_axes=(None, 0), out_axes=0)(a, idx)
    want = mx.zeros((3, 3, 4))
    for k in range(3):
        want[k, k] = 1
    check(got, want)

    # vmap over src, indices, updates
    def assign_at(x, where, values):
        x[where] = values
        return x

    a = mx.zeros((3, 4))
    idx = mx.array([0, 1, 2])
    updates = mx.array([1, 2, 3])
    got = mx.vmap(assign_at, in_axes=(0, 0, 0), out_axes=0)(a, idx, updates)
    want = mx.diag(mx.array([1, 2, 3]), k=-1)[1:]
    check(got, want)

    # vmap over only updates
    a = mx.zeros((3, 4))
    idx = mx.array([0])
    updates = mx.array([1, 2, 3])
    got = mx.vmap(assign_at, in_axes=(None, None, 0), out_axes=0)(a, idx, updates)
    want = mx.zeros((3, 3, 4))
    want[:, 0] = mx.array([1, 2, 3])[:, None]
    check(got, want)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -414,6 +414,94 @@ TEST_CASE("test vmap gather") {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap scatter") {
  // Factory for a one-input function that scatters fixed `updates` into its
  // argument at fixed `indices` along `axes`.
  auto make_scatter_fn = [](const std::vector<array>& indices,
                            const array& updates,
                            const std::vector<int>& axes) {
    return [indices, updates, axes](const std::vector<array>& args) {
      auto src = args.at(0);
      std::vector<array> outputs{scatter(src, indices, updates, axes)};
      return outputs;
    };
  };

  {
    // vmap nothing.
    auto src = zeros({3, 4});
    auto idx = array({1});
    auto upd = reshape(array({1, 2}, float32), {1, 1, 2});

    auto fn = make_scatter_fn({idx}, upd, std::vector<int>{0});
    auto result = vmap(fn, /* in_axes = */ {-1}, /* out_axes = */ {-1})({src})[0];
    auto want = array({0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0}, {3, 4}, float32);
    // Non-vmapped function output.
    CHECK(array_equal(fn({src}).at(0), want).item<bool>());
    CHECK(array_equal(result, want).item<bool>());
  }

  {
    // vmap src on axis 0, scatter on axis 0.
    auto src = zeros({2, 3, 4});
    auto idx = array({1});
    auto upd = reshape(array({1, 2}, float32), {1, 1, 2});

    auto fn = make_scatter_fn({idx}, upd, std::vector<int>{0});
    auto result = vmap(fn, /* in_axes = */ {0})({src})[0];
    auto want = array(
        {0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0},
        {2, 3, 4},
        float32);
    CHECK(array_equal(result, want).item<bool>());
  }

  {
    // vmap src on axis 1, scatter on axis 0.
    auto src = zeros({3, 2, 4});
    auto idx = array({1});
    auto upd = reshape(array({1, 2}, float32), {1, 1, 2});

    auto fn = make_scatter_fn({idx}, upd, std::vector<int>{0});
    auto result = vmap(fn, /* in_axes = */ {1}, /* out_axes = */ {1})({src})[0];
    auto want = array(
        {0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0,
         1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
        {3, 2, 4},
        float32);
    CHECK(array_equal(result, want).item<bool>());
  }

  {
    // vmap src on axis 0, scatter on axis 1.
    auto src = zeros({2, 3, 4});
    auto idx = array({1});
    auto upd = reshape(array({1, 2}, float32), {1, 2, 1});

    auto fn = make_scatter_fn({idx}, upd, std::vector<int>{1});
    auto result = vmap(fn, /* in_axes = */ {0})({src})[0];
    auto want = array(
        {0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0,
         0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0},
        {2, 3, 4},
        float32);
    CHECK(array_equal(result, want).item<bool>());
  }

  {
    // vmap src on axis 2, scatter on axes (0, 1).
    auto src = zeros({2, 3, 2});
    std::vector<array> idx{array({1}), array({2})};
    std::vector<int> scatter_axes{0, 1};
    auto upd = reshape(array({1}, float32), {1, 1, 1});

    auto fn = make_scatter_fn(idx, upd, scatter_axes);
    auto result = vmap(fn, /* in_axes = */ {2}, /* out_axes = */ {2})({src})[0];
    auto want = array({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1}, {2, 3, 2}, float32);
    CHECK(array_equal(result, want).item<bool>());
  }
}
|
||||
|
||||
TEST_CASE("test vmap SVD") {
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
return linalg::svd(inputs.at(0), Device::cpu);
|
||||
|
Loading…
Reference in New Issue
Block a user