Add vmap to scatter (#1200)

* Add vmap to scatter

* updates

* vmap updates + a few more tests

* bug fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
nicolov
2024-08-06 05:12:27 +02:00
committed by GitHub
parent 58d0e199e1
commit 8c9f0278b9
6 changed files with 269 additions and 8 deletions

View File

@@ -1103,6 +1103,9 @@ array moveaxis(
};
source = check_ax(source);
destination = check_ax(destination);
if (source == destination) {
return a;
}
std::vector<int> reorder(a.ndim());
std::iota(reorder.begin(), reorder.end(), 0);
reorder.erase(reorder.begin() + source);
@@ -2715,9 +2718,8 @@ array scatter(
if (updates.ndim() != (a.ndim() + idx_shape.size())) {
std::ostringstream msg;
msg << "[scatter] Updates with " << updates.ndim()
<< " dimensions does not match the sum of the array and indices "
"dimensions "
<< a.ndim() + idx_shape.size() << ".";
<< " dimensions does not match the sum of the array (" << a.ndim()
<< ") and indices (" << idx_shape.size() << ") dimensions.";
throw std::invalid_argument(msg.str());
}
for (int i = 0; i < idx_shape.size(); ++i) {
@@ -2759,11 +2761,12 @@ array scatter(
inputs.insert(inputs.begin(), a);
// TODO promote or cast?
inputs.push_back(astype(updates, a.dtype(), s));
return array(
a.shape(),
a.dtype(),
std::make_shared<Scatter>(to_stream(s), mode, axes),
inputs);
std::move(inputs));
}
array scatter(

View File

@@ -1,4 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
@@ -2976,6 +2977,77 @@ std::vector<array> Scatter::jvp(
throw std::runtime_error("[scatter] JVP not yet implemented");
}
// Vectorizing transform for Scatter.
//
// `inputs_` is laid out as {source, index arrays..., updates} and `vmap_axes`
// holds, for each input, the axis being vmapped or a negative value when that
// input is not vmapped.
//
// Every input is given a batch dimension (moved to the front for indices and
// updates, inserted at 0 for a source that is not vmapped). An extra index
// array enumerating the batch entries is appended so that each batch entry
// scatters into its own slice of the source along `src_ax`.
//
// Returns the new Scatter output together with the axis of the output that
// carries the batch dimension (the source's vmap axis, or 0 when the source
// itself was not vmapped).
std::pair<std::vector<array>, std::vector<int>> Scatter::vmap(
    const std::vector<array>& inputs_,
    const std::vector<int>& vmap_axes) {
  assert(inputs_.size() >= 2);
  assert(inputs_.size() == vmap_axes.size());

  auto inputs = inputs_;
  auto scatter_axes = axes_;
  int src_ax = vmap_axes[0];

  // Locate the first vmapped input; its batch axis gives the vmap size.
  auto vmap_ax_it = std::find_if(
      vmap_axes.begin(), vmap_axes.end(), [](int a) { return a >= 0; });

  // Compare against end() *before* dereferencing: reading through the end
  // iterator when nothing is vmapped is undefined behaviour.
  if (vmap_ax_it != vmap_axes.end()) {
    auto vmap_ax = *vmap_ax_it;
    auto vmap_size = inputs[vmap_ax_it - vmap_axes.begin()].shape(vmap_ax);

    // A source that is not vmapped is replicated along a new leading axis so
    // every batch entry has its own copy to scatter into.
    if (src_ax < 0) {
      src_ax = 0;
      inputs[0] =
          repeat(expand_dims(inputs[0], 0, stream()), vmap_size, 0, stream());
    }

    // Signed count avoids the unsigned wrap-around of `size() - 1`.
    const int num_inputs = static_cast<int>(vmap_axes.size());
    for (int i = 1; i < num_inputs - 1; ++i) {
      if (vmap_axes[i] >= 0) {
        // vmap axis for indices goes to 0
        inputs[i] = moveaxis(inputs[i], vmap_axes[i], 0, stream());
      } else {
        // insert a vmap axis and repeat
        inputs[i] =
            repeat(expand_dims(inputs[i], 0, stream()), vmap_size, 0, stream());
      }
      // Adjust non-vmapped index axes to account for the extra vmap dimension.
      if (scatter_axes[i - 1] >= src_ax) {
        scatter_axes[i - 1]++;
      }
    }

    // Extra index array selecting, for each batch entry, its own position
    // along the source's batch axis.
    auto vmap_inds = arange(vmap_size, inputs[1].dtype(), stream());
    auto vmap_inds_shape = std::vector<int>(inputs[1].ndim(), 1);
    vmap_inds_shape[0] = vmap_inds.size();
    vmap_inds = reshape(vmap_inds, std::move(vmap_inds_shape), stream());
    inputs.insert(
        inputs.end() - 1, broadcast_to(vmap_inds, inputs[1].shape(), stream()));
    scatter_axes.push_back(src_ax);

    // Clone updates along the vmap dimension so they can be applied to each
    // source tensor in the vmap.
    // NOTE(review): the singleton slice dimension is inserted directly after
    // the index dimensions (slice position 0) while the source's batch axis is
    // `src_ax` -- confirm this layout is what Scatter expects when src_ax != 0.
    const int num_idx_dims = static_cast<int>(inputs[1].ndim());
    auto& updates = inputs.back();
    if (vmap_axes.back() < 0) {
      updates = expand_dims(updates, {0, num_idx_dims}, stream());
      updates = repeat(updates, vmap_size, 0, stream());
    } else {
      // Move the batch axis to the front first. Inserting the singleton axis
      // beforehand would shift the batch axis by one whenever it lies at or
      // past the insertion point, and the wrong axis would then be moved.
      updates = moveaxis(updates, vmap_axes.back(), 0, stream());
      updates = expand_dims(updates, num_idx_dims, stream());
    }
  }

  auto& shape = inputs[0].shape();
  auto dtype = inputs[0].dtype();
  auto out = array(
      shape,
      dtype,
      std::make_shared<Scatter>(stream(), reduce_type_, scatter_axes),
      std::move(inputs));

  return {{out}, {src_ax}};
}
std::vector<array> Sigmoid::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@@ -1713,7 +1713,9 @@ class Scatter : public UnaryPrimitive {
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP();
DEFINE_GRADS();
void print(std::ostream& os) override {
os << "Scatter";
switch (reduce_type_) {

View File

@@ -563,8 +563,10 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
detail::InTracing in_tracing;
if (in_axes.size() != inputs.size()) {
throw std::invalid_argument(
"[vmap] The number of in axes must match the number of inputs.");
std::stringstream ss;
ss << "[vmap] The number of in axes (" << in_axes.size()
<< ") must match the number of inputs (" << inputs.size() << ").";
throw std::invalid_argument(ss.str());
}
// Some error checking and get the vmap axis size
@@ -620,8 +622,10 @@ std::vector<array> vmap_replace(
const std::vector<int>& in_axes,
const std::vector<int>& out_axes) {
if (out_axes.size() != s_outputs.size()) {
throw std::invalid_argument(
"[vmap] The number of out axes must match the number of outputs.");
std::stringstream msg;
msg << "[vmap] The number of out axes (" << out_axes.size()
<< ") must match the number of outputs (" << s_outputs.size() << ").";
throw std::invalid_argument(msg.str());
}
std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;