scatter axis + gather axis primitives (#1813)

* scatter axis + gather axis primitives

* add transforms

* comment
This commit is contained in:
Awni Hannun
2025-01-31 20:48:08 -08:00
committed by GitHub
parent c6fc07f1f4
commit b7c9f1d38f
15 changed files with 1037 additions and 85 deletions

View File

@@ -2098,6 +2098,77 @@ bool Gather::is_equivalent(const Primitive& other) const {
return axes_ == g_other.axes_ && slice_sizes_ == g_other.slice_sizes_;
}
std::pair<std::vector<array>, std::vector<int>> GatherAxis::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
bool vmap_in = axes[0] >= 0;
bool vmap_idx = axes[1] >= 0;
auto in = inputs[0];
auto idx = inputs[1];
int out_ax;
if (vmap_in && vmap_idx) {
// reorder the vmap axes to the same location
idx = moveaxis(idx, axes[1], axes[0], stream());
out_ax = axes[0];
} else if (vmap_in) {
// expand just the indices dimension
idx = expand_dims(idx, axes[0], stream());
out_ax = axes[0];
} else if (vmap_idx) {
// expand just the input dimension
in = expand_dims(in, axes[1], stream());
out_ax = axes[1];
} else {
out_ax = -1;
}
int axis = (out_ax >= 0 && axis_ >= out_ax) ? axis_ + 1 : axis_;
return {{take_along_axis(in, idx, axis, stream())}, {out_ax}};
}
std::vector<array> GatherAxis::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
for (int argnum : argnums) {
if (argnum > 0) {
// Grads w.r.t. indices are zero
vjps.push_back(
zeros(primals[argnum].shape(), primals[argnum].dtype(), stream()));
} else {
auto src = zeros_like(primals[0], stream());
vjps.push_back(array(
src.shape(),
src.dtype(),
std::make_shared<ScatterAxis>(stream(), ScatterAxis::Sum, axis_),
{src, primals[1], cotangents[0]}));
}
}
return vjps;
}
// Jacobian-vector product for GatherAxis.
//
// Only differentiation with respect to the input (argument 0) is supported;
// any other request throws std::invalid_argument. The tangent of the output
// is the input tangent gathered with the same indices along axis_.
std::vector<array> GatherAxis::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  // Evaluated left to right so argnums[0] is only read when size() <= 1,
  // exactly as in a `size() > 1 || argnums[0] != 0` check.
  const bool input_only = argnums.size() <= 1 && argnums[0] == 0;
  if (!input_only) {
    throw std::invalid_argument(
        "[gather_axis] Cannot calculate JVP with respect to indices.");
  }
  const auto& indices = primals[1];
  return {take_along_axis(tangents[0], indices, axis_, stream())};
}
// The output of GatherAxis has the same shape as the indices (inputs[1]).
std::vector<Shape> GatherAxis::output_shapes(const std::vector<array>& inputs) {
  const auto& indices = inputs[1];
  return {indices.shape()};
}
// Two GatherAxis primitives are equivalent when they gather along the same
// axis. `other` is cast without a check, so it must already be a GatherAxis.
bool GatherAxis::is_equivalent(const Primitive& other) const {
  const auto& rhs = static_cast<const GatherAxis&>(other);
  return rhs.axis_ == axis_;
}
std::vector<Shape> Gather::output_shapes(const std::vector<array>& inputs) {
Shape out_shape;
if (inputs.size() > 1) {
@@ -3621,6 +3692,117 @@ std::pair<std::vector<array>, std::vector<int>> Scatter::vmap(
return {{out}, {src_ax}};
}
std::vector<array> ScatterAxis::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const auto& indices = primals[1];
const auto& updates = primals[2];
std::vector<array> vjps;
for (auto num : argnums) {
// Gradient wrt to the input array
if (num == 0) {
if (reduce_type_ == ScatterAxis::None) {
// Scatter 0s to the locations that were updated with the updates
vjps.push_back(put_along_axis(
cotangents[0],
indices,
zeros_like(updates, stream()),
axis_,
stream()));
} else {
// The input array values are kept so they all get gradients
vjps.push_back(cotangents[0]);
}
} else if (num == 2) {
vjps.push_back(take_along_axis(cotangents[0], indices, axis_, stream()));
} else {
throw std::invalid_argument(
"[scatter_axis] Cannot calculate VJP with respect to indices.");
}
}
return vjps;
}
std::vector<array> ScatterAxis::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
for (auto arg : argnums) {
if (arg == 1) {
throw std::invalid_argument(
"[scatter_axis] Cannot calculate JVP with respect to indices.");
}
}
if (argnums.size() == 2) {
return {array(
primals[0].shape(),
primals[0].dtype(),
std::make_shared<ScatterAxis>(stream(), reduce_type_, axis_),
{tangents[0], primals[1], tangents[1]})};
} else {
auto tan_a =
argnums[0] == 0 ? tangents[0] : zeros_like(primals[0], stream());
auto tan_b =
argnums[0] == 2 ? tangents[0] : zeros_like(primals[2], stream());
return {array(
primals[0].shape(),
primals[0].dtype(),
std::make_shared<ScatterAxis>(stream(), reduce_type_, axis_),
{tan_a, primals[1], tan_b})};
}
}
// Batching rule for ScatterAxis.
//
// inputs are {input, indices, updates} and axes holds the batch axis of each
// (negative when that operand is not batched). All operands are brought to a
// common batch axis, the axis of the first batched operand, and the scatter is
// re-issued along the correspondingly shifted axis. Returns the result and
// its batch axis (-1 when nothing is batched).
std::pair<std::vector<array>, std::vector<int>> ScatterAxis::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  // Find the first vmap axis
  int out_ax = -1;
  for (auto ax : axes) {
    if (ax >= 0) {
      out_ax = ax;
      break;
    }
  }

  // Nothing is batched: rebuild the same primitive on the unchanged inputs.
  if (out_ax < 0) {
    return {
        {array(
            inputs[0].shape(),
            inputs[0].dtype(),
            std::make_shared<ScatterAxis>(stream(), reduce_type_, axis_),
            inputs)},
        {-1}};
  }

  auto v_in = inputs;
  // Signed count so the index comparison below does not mix signedness.
  const int num_axes = static_cast<int>(axes.size());
  for (int i = 0; i < num_axes; ++i) {
    if (axes[i] < 0) {
      // Not batched: insert a dimension at the common batch axis.
      v_in[i] = expand_dims(v_in[i], out_ax, stream());
    } else if (axes[i] != out_ax) {
      // Batched on a different axis: move it to the common batch axis.
      v_in[i] = moveaxis(v_in[i], axes[i], out_ax, stream());
    }
  }

  // out_ax >= 0 here; a batch dimension at or before axis_ shifts it by one.
  int axis = axis_ >= out_ax ? axis_ + 1 : axis_;
  auto fn = reduce_type_ == Sum ? scatter_add_axis : put_along_axis;
  return {{fn(v_in[0], v_in[1], v_in[2], axis, stream())}, {out_ax}};
}
// The output of ScatterAxis has the same shape as the array being scattered
// into (inputs[0]).
std::vector<Shape> ScatterAxis::output_shapes(
    const std::vector<array>& inputs) {
  const auto& target = inputs[0];
  return {target.shape()};
}
// Two ScatterAxis primitives are equivalent when both the scatter axis and
// the reduction type match. `other` is cast without a check, so it must
// already be a ScatterAxis.
bool ScatterAxis::is_equivalent(const Primitive& other) const {
  const auto& rhs = static_cast<const ScatterAxis&>(other);
  const bool same_axis = axis_ == rhs.axis_;
  const bool same_reduce = reduce_type_ == rhs.reduce_type_;
  return same_reduce && same_axis;
}
std::vector<array> Sigmoid::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,