Dynamic slicing (#1741)

* dynamic slice and slice update

* python bindings + tests + fix set item

* fix compile issue

* comment

* fix jit
This commit is contained in:
Awni Hannun
2025-01-07 14:02:16 -08:00
committed by GitHub
parent c9c81d0584
commit 516ded618b
27 changed files with 941 additions and 75 deletions

View File

@@ -3715,22 +3715,18 @@ std::vector<array> SliceUpdate::vjp(
for (int num : argnums) {
// Vjp for source
if (num == 0) {
auto grad = slice_update(
vjps.push_back(slice_update(
cotan,
zeros_like(upd, stream()),
start_indices_,
end_indices_,
strides_,
stream());
vjps.push_back(grad);
stream()));
}
// Vjp for updates
else {
auto grad =
slice(cotan, start_indices_, end_indices_, strides_, stream());
vjps.push_back(grad);
vjps.push_back(
slice(cotan, start_indices_, end_indices_, strides_, stream()));
}
}
@@ -3753,12 +3749,153 @@ std::vector<array> SliceUpdate::jvp(
}
// Two SliceUpdate primitives are equivalent when their start indices, end
// indices and strides all compare equal.
// NOTE(review): `other` is downcast with static_cast, so this presumably
// relies on the caller having already checked that `other` is a SliceUpdate —
// confirm against the call site.
bool SliceUpdate::is_equivalent(const Primitive& other) const {
  // Declare the downcast reference exactly once; a second declaration of the
  // same name in this scope would be a redeclaration.
  const auto& s_other = static_cast<const SliceUpdate&>(other);
  return (
      start_indices_ == s_other.start_indices_ &&
      end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_);
}
std::pair<std::vector<array>, std::vector<int>> DynamicSlice::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto& in = inputs[0];
auto& start = inputs[1];
auto vax = axes[0];
if (axes[1] >= 0) {
throw std::invalid_argument(
"[DynamicSlice::vmap] vmap over start indices not yet supported.");
}
auto slice_size = slice_size_;
auto slice_axes = axes_;
if (vax >= 0) {
for (auto& ax : slice_axes) {
if (ax >= vax) {
ax++;
}
}
slice_size.insert(slice_size.begin() + vax, in.shape(vax));
}
return {
{slice(
in, start, std::move(slice_axes), std::move(slice_size), stream())},
{vax}};
}
// Vector-Jacobian product for DynamicSlice.
//
// Only argument 0 (the sliced array) is supported: std::invalid_argument is
// thrown when the first requested argnum is 1 or when more than one argnum is
// requested. The result is a zero array shaped like primals[0] with
// cotangents[0] written at the position given by primals[1] along axes_.
std::vector<array> DynamicSlice::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>&) {
  const bool wants_start_grad = argnums[0] == 1 || argnums.size() > 1;
  if (wants_start_grad) {
    throw std::invalid_argument(
        "[DynamicSlice::vjp] Not supported for start indices.");
  }

  const auto& source = primals[0];
  const auto& start_idx = primals[1];
  const auto& cotan = cotangents[0];

  return {slice_update(
      zeros_like(source, stream()), cotan, start_idx, axes_, stream())};
}
// Jacobian-vector product for DynamicSlice: applies the same slice (start
// indices from primals[1], this primitive's axes_ and slice_size_) to
// tangents[0]. The argnums parameter is not consulted.
std::vector<array> DynamicSlice::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>&) {
  const auto& source_tangent = tangents[0];
  const auto& start_idx = primals[1];
  return {slice(source_tangent, start_idx, axes_, slice_size_, stream())};
}
// Two DynamicSlice primitives are equivalent when both their axes and their
// slice sizes compare equal.
// NOTE(review): `other` is downcast with static_cast, so this presumably
// relies on the caller having already matched the primitive type — confirm.
bool DynamicSlice::is_equivalent(const Primitive& other) const {
  const auto& rhs = static_cast<const DynamicSlice&>(other);
  if (!(axes_ == rhs.axes_)) {
    return false;
  }
  return slice_size_ == rhs.slice_size_;
}
std::vector<Shape> DynamicSlice::output_shapes(const std::vector<array>&) {
return {slice_size_};
}
// vmap rule for DynamicSliceUpdate.
//
// inputs are {source, update, start indices}; axes[i] is the vmapped axis of
// inputs[i], with -1 meaning "not vmapped". A non-negative axes[2] is
// rejected with std::runtime_error.
//
// Source and update are brought to a common vmapped axis (the source's axis
// when it has one, otherwise the update's), the slice axes are shifted to
// account for it, and the update is re-issued. The returned axis is that
// common axis, or -1 when neither operand was vmapped.
std::pair<std::vector<array>, std::vector<int>> DynamicSliceUpdate::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  if (axes[2] >= 0) {
    throw std::runtime_error(
        "[DynamicSliceUpdate::vmap] vmap over start indices not yet supported.");
  }

  // Copies, since either operand may be replaced by a reshaped version below.
  auto source = inputs[0];
  auto update = inputs[1];
  const auto& start_idx = inputs[2];
  auto source_ax = axes[0];
  auto update_ax = axes[1];

  if (source_ax == -1) {
    if (update_ax == -1) {
      // Neither operand is vmapped: plain update, no output axis.
      return {{slice_update(source, update, start_idx, axes_, stream())}, {-1}};
    }
    // Only the update is vmapped: give the source a matching axis and
    // broadcast it to the update's extent along that axis.
    source = expand_dims(source, update_ax, stream());
    auto target_shape = source.shape();
    target_shape[update_ax] = update.shape(update_ax);
    source = broadcast_to(source, target_shape, stream());
    source_ax = update_ax;
  } else if (update_ax == -1) {
    // Only the source is vmapped: add a singleton axis to the update.
    update = expand_dims(update, source_ax, stream());
    update_ax = source_ax;
  }

  // Align the update's vmapped axis with the source's.
  if (source_ax != update_ax) {
    update = moveaxis(update, update_ax, source_ax, stream());
  }

  // Slice axes at or after the vmapped axis move one position right.
  auto shifted_axes = axes_;
  for (auto& ax : shifted_axes) {
    if (ax < source_ax) {
      continue;
    }
    ++ax;
  }

  auto out =
      slice_update(source, update, start_idx, std::move(shifted_axes), stream());
  return {{out}, {source_ax}};
}
std::vector<array> DynamicSliceUpdate::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
auto& cotan = cotangents[0];
auto& upd = primals[1];
auto& start = primals[2];
std::vector<array> vjps;
for (int num : argnums) {
if (num == 0) {
// Vjp for source
vjps.push_back(slice_update(
cotan, zeros_like(upd, stream()), start, axes_, stream()));
} else if (num == 1) {
// Vjp fpr updates
vjps.push_back(slice(cotan, start, axes_, upd.shape(), stream()));
} else {
throw std::invalid_argument(
"[DynamicSliceUpdate::vjp] Not supported for start indices");
}
}
return vjps;
}
// Jacobian-vector product for DynamicSliceUpdate: writes the update's tangent
// (tangents[1]) into the source's tangent (tangents[0]) at the start indices
// from primals[2], along axes_. The argnums parameter is not consulted.
std::vector<array> DynamicSliceUpdate::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>&) {
  const auto& source_tangent = tangents[0];
  const auto& update_tangent = tangents[1];
  const auto& start_idx = primals[2];
  return {
      slice_update(source_tangent, update_tangent, start_idx, axes_, stream())};
}
// Two DynamicSliceUpdate primitives are equivalent when their axes compare
// equal.
// NOTE(review): `other` is downcast with static_cast, so this presumably
// relies on the caller having already matched the primitive type — confirm.
bool DynamicSliceUpdate::is_equivalent(const Primitive& other) const {
  const auto& rhs = static_cast<const DynamicSliceUpdate&>(other);
  const bool same_axes = (axes_ == rhs.axes_);
  return same_axes;
}
std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {