Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
Add vmap to scatter (#1200)
* Add vmap to scatter
* updates
* vmap updates + a few more tests
* bug fix

--------

Co-authored-by: Awni Hannun <awni@apple.com>
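For orientation, here is a hypothetical sketch of the kind of call this change enables: vmapping a function whose body lowers to the Scatter primitive. Nothing in it is taken from the PR; it assumes the vector-form vmap from mlx/transforms.h, the low-level scatter op from mlx/ops.h, -1 as the "not mapped" axis marker, and made-up shapes and axes.

    // Hypothetical usage sketch, not part of this commit: vmap over a
    // function whose body lowers to the Scatter primitive. Assumes the
    // vector-form vmap from mlx/transforms.h and the low-level scatter
    // from mlx/ops.h; shapes and axes are illustrative.
    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      // Per example: scatter two size-{1} updates into a length-4 row.
      // Batched: 3 rows and 3 update sets, one index array shared by all.
      auto src = zeros({3, 4});    // vmapped along axis 0
      auto idx = array({0, 2});    // shared across the batch (in_axis -1)
      auto upd = ones({3, 2, 1});  // vmapped along axis 0

      auto fun = [](const std::vector<array>& ins) -> std::vector<array> {
        // Write ins[2] into ins[0] at positions ins[1] along axis 0.
        return {scatter(ins[0], {ins[1]}, ins[2], {0})};
      };

      // The new Scatter::vmap rule is what lets this trace succeed.
      auto vfun = vmap(fun, /* in_axes = */ {0, -1, 0}, /* out_axes = */ {0});
      auto out = vfun({src, idx, upd})[0];  // shape {3, 4}
      out.eval();
      return 0;
    }

With in_axes = {0, -1, 0}, the source and updates are batched on axis 0 while the index array is shared across the batch, which is exactly the mixed-axes case the new rule below has to normalize.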
@@ -1,4 +1,5 @@
 // Copyright © 2023-2024 Apple Inc.
 
+#include <algorithm>
 #include <cassert>
 #include <cmath>
@@ -2976,6 +2977,77 @@ std::vector<array> Scatter::jvp(
   throw std::runtime_error("[scatter] JVP not yet implemented");
 }
 
+std::pair<std::vector<array>, std::vector<int>> Scatter::vmap(
+    const std::vector<array>& inputs_,
+    const std::vector<int>& vmap_axes) {
+  assert(inputs_.size() >= 2);
+  assert(inputs_.size() == vmap_axes.size());
+
+  auto inputs = inputs_;
+
+  auto scatter_axes = axes_;
+  int src_ax = vmap_axes[0];
+
+  auto vmap_ax_it = std::find_if(
+      vmap_axes.begin(), vmap_axes.end(), [](int a) { return a >= 0; });
+  auto vmap_ax = *vmap_ax_it;
+  if (vmap_ax >= 0) {
+    auto vmap_size = inputs[vmap_ax_it - vmap_axes.begin()].shape(vmap_ax);
+    if (src_ax < 0) {
+      src_ax = 0;
+      inputs[0] =
+          repeat(expand_dims(inputs[0], 0, stream()), vmap_size, 0, stream());
+    }
+    for (int i = 1; i < vmap_axes.size() - 1; ++i) {
+      // vmap axis for indices goes to 0
+      if (vmap_axes[i] >= 0) {
+        inputs[i] = moveaxis(inputs[i], vmap_axes[i], 0, stream());
+      }
+      // insert a vmap axis and repeat
+      if (vmap_axes[i] < 0) {
+        auto idx_shape = inputs[i].shape();
+        inputs[i] =
+            repeat(expand_dims(inputs[i], 0, stream()), vmap_size, 0, stream());
+      }
+      // Adjust non-vmapped index axes to account for the extra vmap dimension.
+      if (scatter_axes[i - 1] >= src_ax) {
+        scatter_axes[i - 1]++;
+      }
+    }
+
+    auto vmap_inds = arange(vmap_size, inputs[1].dtype(), stream());
+    auto vmap_inds_shape = std::vector<int>(inputs[1].ndim(), 1);
+    vmap_inds_shape[0] = vmap_inds.size();
+    vmap_inds = reshape(vmap_inds, std::move(vmap_inds_shape), stream());
+    inputs.insert(
+        inputs.end() - 1, broadcast_to(vmap_inds, inputs[1].shape(), stream()));
+    scatter_axes.push_back(src_ax);
+
+    // Clone updates along the vmap dimension so they can be applied to each
+    // source tensor in the vmap.
+    auto& updates = inputs.back();
+    if (vmap_axes.back() < 0) {
+      updates = expand_dims(
+          updates, {0, static_cast<int>(inputs[1].ndim())}, stream());
+      updates = repeat(updates, vmap_size, 0, stream());
+    } else {
+      updates =
+          expand_dims(updates, static_cast<int>(inputs[1].ndim()), stream());
+      updates = moveaxis(updates, vmap_axes.back(), 0, stream());
+    }
+  }
+
+  auto& shape = inputs[0].shape();
+  auto dtype = inputs[0].dtype();
+  auto out = array(
+      shape,
+      dtype,
+      std::make_shared<Scatter>(stream(), reduce_type_, scatter_axes),
+      std::move(inputs));
+
+  return {{out}, {src_ax}};
+}
+
 std::vector<array> Sigmoid::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
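The heart of the new rule is visible in the diff above: a batched scatter is rewritten as an unbatched scatter with one extra index array (an arange over the batch dimension, broadcast to the indices' shape) and one extra scatter axis (the batch axis), so every batch element writes into its own slice of the batched source. Below is a dependency-free sketch of that equivalence; all names and shapes are made up for illustration.

    // Standalone illustration (no MLX) of the batch-index trick used above:
    // a batched scatter is an unbatched scatter with one extra index, the
    // batch id, and one extra axis, the batch axis.
    #include <cstdio>
    #include <vector>

    int main() {
      const int B = 2, N = 5;               // batch size, row length
      std::vector<float> src(B * N, 0.0f);  // batched source, shape {B, N}
      const std::vector<int> idx = {0, 3};  // per-example indices, shared
      const std::vector<float> upd = {      // updates, shape {B, 2}
          1.0f, 2.0f,   // batch 0
          3.0f, 4.0f};  // batch 1

      // Unbatched: src[idx[j]] = upd[j]. Batched: pair each index with an
      // explicit batch id b (the arange the vmap rule prepends) and scatter
      // along axes {0, 1} of the {B, N} source.
      for (int b = 0; b < B; ++b) {
        for (size_t j = 0; j < idx.size(); ++j) {
          src[b * N + idx[j]] = upd[b * idx.size() + j];
        }
      }

      for (int b = 0; b < B; ++b) {
        for (int n = 0; n < N; ++n) {
          std::printf("%4.1f ", src[b * N + n]);
        }
        std::printf("\n");
      }
      return 0;
    }

In the diff, vmap_inds plays the role of the explicit batch id b, broadcast_to stretches it to the indices' shape, and scatter_axes.push_back(src_ax) adds the batch axis to the scatter.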