mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Put along axis + fixe for partition grad (#1430)
* put along axis, fixes for partition grad * zeros for arg reduce
This commit is contained in:
@@ -471,6 +471,21 @@ std::pair<std::vector<array>, std::vector<int>> ArgPartition::vmap(
|
||||
return {{argpartition(inputs[0], axis_ + axis_left, stream())}, axes};
|
||||
}
|
||||
|
||||
std::vector<array> ArgPartition::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>&) {
|
||||
return {zeros_like(primals[0], stream())};
|
||||
}
|
||||
|
||||
std::vector<array> ArgPartition::jvp(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>&) {
|
||||
return {zeros_like(tangents[0], stream())};
|
||||
}
|
||||
|
||||
bool ArgPartition::is_equivalent(const Primitive& other) const {
|
||||
const ArgPartition& r_other = static_cast<const ArgPartition&>(other);
|
||||
return axis_ == r_other.axis_ && kth_ == r_other.kth_;
|
||||
@@ -495,6 +510,21 @@ std::pair<std::vector<array>, std::vector<int>> ArgReduce::vmap(
|
||||
return {out, axes};
|
||||
}
|
||||
|
||||
std::vector<array> ArgReduce::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>&) {
|
||||
return {zeros_like(primals[0], stream())};
|
||||
}
|
||||
|
||||
std::vector<array> ArgReduce::jvp(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>&) {
|
||||
return {zeros_like(tangents[0], stream())};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
@@ -2336,7 +2366,13 @@ std::vector<array> Partition::vjp(
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
return jvp(primals, cotangents, argnums);
|
||||
auto sort_idx = argpartition(primals[0], kth_, axis_, stream());
|
||||
return {put_along_axis(
|
||||
zeros_like(primals[0], stream()),
|
||||
sort_idx,
|
||||
cotangents[0],
|
||||
axis_,
|
||||
stream())};
|
||||
}
|
||||
|
||||
std::vector<array> Partition::jvp(
|
||||
|
||||
Reference in New Issue
Block a user