mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Dynamic slicing (#1741)
* dynamic slice and slice update * python bindings + tests + fix set item * fix compile issue * comment * fix jit
This commit is contained in:
@@ -3715,22 +3715,18 @@ std::vector<array> SliceUpdate::vjp(
|
||||
for (int num : argnums) {
|
||||
// Vjp for source
|
||||
if (num == 0) {
|
||||
auto grad = slice_update(
|
||||
vjps.push_back(slice_update(
|
||||
cotan,
|
||||
zeros_like(upd, stream()),
|
||||
start_indices_,
|
||||
end_indices_,
|
||||
strides_,
|
||||
stream());
|
||||
|
||||
vjps.push_back(grad);
|
||||
stream()));
|
||||
}
|
||||
// Vjp for updates
|
||||
else {
|
||||
auto grad =
|
||||
slice(cotan, start_indices_, end_indices_, strides_, stream());
|
||||
|
||||
vjps.push_back(grad);
|
||||
vjps.push_back(
|
||||
slice(cotan, start_indices_, end_indices_, strides_, stream()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3753,12 +3749,153 @@ std::vector<array> SliceUpdate::jvp(
|
||||
}
|
||||
|
||||
bool SliceUpdate::is_equivalent(const Primitive& other) const {
|
||||
const SliceUpdate& s_other = static_cast<const SliceUpdate&>(other);
|
||||
const auto& s_other = static_cast<const SliceUpdate&>(other);
|
||||
return (
|
||||
start_indices_ == s_other.start_indices_ &&
|
||||
end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_);
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> DynamicSlice::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto& in = inputs[0];
|
||||
auto& start = inputs[1];
|
||||
auto vax = axes[0];
|
||||
if (axes[1] >= 0) {
|
||||
throw std::invalid_argument(
|
||||
"[DynamicSlice::vmap] vmap over start indices not yet supported.");
|
||||
}
|
||||
auto slice_size = slice_size_;
|
||||
auto slice_axes = axes_;
|
||||
if (vax >= 0) {
|
||||
for (auto& ax : slice_axes) {
|
||||
if (ax >= vax) {
|
||||
ax++;
|
||||
}
|
||||
}
|
||||
slice_size.insert(slice_size.begin() + vax, in.shape(vax));
|
||||
}
|
||||
return {
|
||||
{slice(
|
||||
in, start, std::move(slice_axes), std::move(slice_size), stream())},
|
||||
{vax}};
|
||||
}
|
||||
|
||||
std::vector<array> DynamicSlice::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
if (argnums[0] == 1 || argnums.size() > 1) {
|
||||
throw std::invalid_argument(
|
||||
"[DynamicSlice::vjp] Not supported for start indices.");
|
||||
}
|
||||
auto out = zeros_like(primals[0], stream());
|
||||
return {slice_update(out, cotangents[0], primals[1], axes_, stream())};
|
||||
}
|
||||
|
||||
std::vector<array> DynamicSlice::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>&) {
|
||||
return {slice(tangents[0], primals[1], axes_, slice_size_, stream())};
|
||||
}
|
||||
|
||||
bool DynamicSlice::is_equivalent(const Primitive& other) const {
|
||||
const auto& s_other = static_cast<const DynamicSlice&>(other);
|
||||
return (axes_ == s_other.axes_ && slice_size_ == s_other.slice_size_);
|
||||
}
|
||||
|
||||
std::vector<Shape> DynamicSlice::output_shapes(const std::vector<array>&) {
|
||||
return {slice_size_};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> DynamicSliceUpdate::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto src = inputs[0];
|
||||
auto upd = inputs[1];
|
||||
auto& start = inputs[2];
|
||||
auto src_ax = axes[0];
|
||||
auto upd_ax = axes[1];
|
||||
if (axes[2] >= 0) {
|
||||
throw std::runtime_error(
|
||||
"[DynamicSliceUpdate::vmap] vmap over start indices not yet supported.");
|
||||
}
|
||||
// No vmapping needed
|
||||
if (src_ax == -1 && upd_ax == -1) {
|
||||
return {{slice_update(src, upd, start, axes_, stream())}, {-1}};
|
||||
}
|
||||
|
||||
// Broadcast src
|
||||
if (src_ax == -1) {
|
||||
src = expand_dims(src, upd_ax, stream());
|
||||
auto shape = src.shape();
|
||||
shape[upd_ax] = upd.shape(upd_ax);
|
||||
src = broadcast_to(src, shape, stream());
|
||||
src_ax = upd_ax;
|
||||
}
|
||||
|
||||
// Broadcast upd
|
||||
if (upd_ax == -1) {
|
||||
upd = expand_dims(upd, src_ax, stream());
|
||||
upd_ax = src_ax;
|
||||
}
|
||||
|
||||
if (src_ax != upd_ax) {
|
||||
upd = moveaxis(upd, upd_ax, src_ax, stream());
|
||||
}
|
||||
|
||||
auto slice_axes = axes_;
|
||||
for (auto& ax : slice_axes) {
|
||||
if (ax >= src_ax) {
|
||||
ax++;
|
||||
}
|
||||
}
|
||||
return {
|
||||
{slice_update(src, upd, start, std::move(slice_axes), stream())},
|
||||
{src_ax}};
|
||||
}
|
||||
|
||||
std::vector<array> DynamicSliceUpdate::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
auto& cotan = cotangents[0];
|
||||
auto& upd = primals[1];
|
||||
auto& start = primals[2];
|
||||
|
||||
std::vector<array> vjps;
|
||||
|
||||
for (int num : argnums) {
|
||||
if (num == 0) {
|
||||
// Vjp for source
|
||||
vjps.push_back(slice_update(
|
||||
cotan, zeros_like(upd, stream()), start, axes_, stream()));
|
||||
} else if (num == 1) {
|
||||
// Vjp fpr updates
|
||||
vjps.push_back(slice(cotan, start, axes_, upd.shape(), stream()));
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[DynamicSliceUpdate::vjp] Not supported for start indices");
|
||||
}
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
std::vector<array> DynamicSliceUpdate::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>&) {
|
||||
return {slice_update(tangents[0], tangents[1], primals[2], axes_, stream())};
|
||||
}
|
||||
|
||||
bool DynamicSliceUpdate::is_equivalent(const Primitive& other) const {
|
||||
const auto& s_other = static_cast<const DynamicSliceUpdate&>(other);
|
||||
return axes_ == s_other.axes_;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
|
||||
Reference in New Issue
Block a user