mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
scatter axis + gather axis primitives (#1813)
* scatter axis + gather axis primitives
* add transforms
* comment
This commit is contained in:
@@ -2098,6 +2098,77 @@ bool Gather::is_equivalent(const Primitive& other) const {
|
||||
return axes_ == g_other.axes_ && slice_sizes_ == g_other.slice_sizes_;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> GatherAxis::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
bool vmap_in = axes[0] >= 0;
|
||||
bool vmap_idx = axes[1] >= 0;
|
||||
|
||||
auto in = inputs[0];
|
||||
auto idx = inputs[1];
|
||||
int out_ax;
|
||||
if (vmap_in && vmap_idx) {
|
||||
// reorder the vmap axes to the same location
|
||||
idx = moveaxis(idx, axes[1], axes[0], stream());
|
||||
out_ax = axes[0];
|
||||
} else if (vmap_in) {
|
||||
// expand just the indices dimension
|
||||
idx = expand_dims(idx, axes[0], stream());
|
||||
out_ax = axes[0];
|
||||
} else if (vmap_idx) {
|
||||
// expand just the input dimension
|
||||
in = expand_dims(in, axes[1], stream());
|
||||
out_ax = axes[1];
|
||||
} else {
|
||||
out_ax = -1;
|
||||
}
|
||||
int axis = (out_ax >= 0 && axis_ >= out_ax) ? axis_ + 1 : axis_;
|
||||
return {{take_along_axis(in, idx, axis, stream())}, {out_ax}};
|
||||
}
|
||||
|
||||
std::vector<array> GatherAxis::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
for (int argnum : argnums) {
|
||||
if (argnum > 0) {
|
||||
// Grads w.r.t. indices are zero
|
||||
vjps.push_back(
|
||||
zeros(primals[argnum].shape(), primals[argnum].dtype(), stream()));
|
||||
} else {
|
||||
auto src = zeros_like(primals[0], stream());
|
||||
vjps.push_back(array(
|
||||
src.shape(),
|
||||
src.dtype(),
|
||||
std::make_shared<ScatterAxis>(stream(), ScatterAxis::Sum, axis_),
|
||||
{src, primals[1], cotangents[0]}));
|
||||
}
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
std::vector<array> GatherAxis::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
if (argnums.size() > 1 || argnums[0] != 0) {
|
||||
throw std::invalid_argument(
|
||||
"[gather_axis] Cannot calculate JVP with respect to indices.");
|
||||
}
|
||||
return {take_along_axis(tangents[0], primals[1], axis_, stream())};
|
||||
}
|
||||
|
||||
std::vector<Shape> GatherAxis::output_shapes(const std::vector<array>& inputs) {
|
||||
return {inputs[1].shape()};
|
||||
}
|
||||
|
||||
bool GatherAxis::is_equivalent(const Primitive& other) const {
|
||||
auto& g_other = static_cast<const GatherAxis&>(other);
|
||||
return axis_ == g_other.axis_;
|
||||
}
|
||||
|
||||
std::vector<Shape> Gather::output_shapes(const std::vector<array>& inputs) {
|
||||
Shape out_shape;
|
||||
if (inputs.size() > 1) {
|
||||
@@ -3621,6 +3692,117 @@ std::pair<std::vector<array>, std::vector<int>> Scatter::vmap(
|
||||
return {{out}, {src_ax}};
|
||||
}
|
||||
|
||||
std::vector<array> ScatterAxis::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
const auto& indices = primals[1];
|
||||
const auto& updates = primals[2];
|
||||
|
||||
std::vector<array> vjps;
|
||||
for (auto num : argnums) {
|
||||
// Gradient wrt to the input array
|
||||
if (num == 0) {
|
||||
if (reduce_type_ == ScatterAxis::None) {
|
||||
// Scatter 0s to the locations that were updated with the updates
|
||||
vjps.push_back(put_along_axis(
|
||||
cotangents[0],
|
||||
indices,
|
||||
zeros_like(updates, stream()),
|
||||
axis_,
|
||||
stream()));
|
||||
} else {
|
||||
// The input array values are kept so they all get gradients
|
||||
vjps.push_back(cotangents[0]);
|
||||
}
|
||||
} else if (num == 2) {
|
||||
vjps.push_back(take_along_axis(cotangents[0], indices, axis_, stream()));
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[scatter_axis] Cannot calculate VJP with respect to indices.");
|
||||
}
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
std::vector<array> ScatterAxis::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 1) {
|
||||
throw std::invalid_argument(
|
||||
"[scatter_axis] Cannot calculate JVP with respect to indices.");
|
||||
}
|
||||
}
|
||||
if (argnums.size() == 2) {
|
||||
return {array(
|
||||
primals[0].shape(),
|
||||
primals[0].dtype(),
|
||||
std::make_shared<ScatterAxis>(stream(), reduce_type_, axis_),
|
||||
{tangents[0], primals[1], tangents[1]})};
|
||||
} else {
|
||||
auto tan_a =
|
||||
argnums[0] == 0 ? tangents[0] : zeros_like(primals[0], stream());
|
||||
auto tan_b =
|
||||
argnums[0] == 2 ? tangents[0] : zeros_like(primals[2], stream());
|
||||
return {array(
|
||||
primals[0].shape(),
|
||||
primals[0].dtype(),
|
||||
std::make_shared<ScatterAxis>(stream(), reduce_type_, axis_),
|
||||
{tan_a, primals[1], tan_b})};
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> ScatterAxis::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
// Find the first vmap axis
|
||||
int out_ax = -1;
|
||||
for (auto ax : axes) {
|
||||
if (ax >= 0) {
|
||||
out_ax = ax;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (out_ax < 0) {
|
||||
return {
|
||||
{array(
|
||||
inputs[0].shape(),
|
||||
inputs[0].dtype(),
|
||||
std::make_shared<ScatterAxis>(stream(), reduce_type_, axis_),
|
||||
inputs)},
|
||||
{-1}};
|
||||
}
|
||||
|
||||
auto v_in = inputs;
|
||||
for (int i = 0; i < axes.size(); ++i) {
|
||||
if (axes[i] >= 0) {
|
||||
// if out_ax >= 0 move axis o/w set out_ax
|
||||
if (out_ax != axes[i]) {
|
||||
v_in[i] = moveaxis(v_in[i], axes[i], out_ax, stream());
|
||||
}
|
||||
} else {
|
||||
v_in[i] = expand_dims(v_in[i], out_ax, stream());
|
||||
}
|
||||
}
|
||||
int axis = axis_ >= out_ax ? axis_ + 1 : axis_;
|
||||
auto fn = reduce_type_ == Sum ? scatter_add_axis : put_along_axis;
|
||||
return {{fn(v_in[0], v_in[1], v_in[2], axis, stream())}, {out_ax}};
|
||||
}
|
||||
|
||||
std::vector<Shape> ScatterAxis::output_shapes(
|
||||
const std::vector<array>& inputs) {
|
||||
return {inputs[0].shape()};
|
||||
}
|
||||
|
||||
bool ScatterAxis::is_equivalent(const Primitive& other) const {
|
||||
auto& s_other = static_cast<const ScatterAxis&>(other);
|
||||
return reduce_type_ == s_other.reduce_type_ && axis_ == s_other.axis_;
|
||||
}
|
||||
|
||||
std::vector<array> Sigmoid::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
||||
Reference in New Issue
Block a user