mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
ExpandDims primitive (#1687)
* add squeeze primitive * simplify squeeze, use in gather * fix * fix * fix * fix * fix no cpu * use squeeze in matmul and friends * expand dims primitive * comment
This commit is contained in:
@@ -1602,6 +1602,55 @@ std::pair<std::vector<array>, std::vector<int>> Expm1::vmap(
|
||||
return {{expm1(inputs[0], stream())}, axes};
|
||||
}
|
||||
|
||||
std::vector<array> ExpandDims::vjp(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>&) {
|
||||
return {squeeze(cotangents[0], axes_, stream())};
|
||||
}
|
||||
|
||||
std::vector<array> ExpandDims::jvp(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>&) {
|
||||
return {expand_dims(tangents[0], axes_, stream())};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> ExpandDims::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto ax = axes[0];
|
||||
auto expand_axes = axes_;
|
||||
for (auto& s : expand_axes) {
|
||||
if (s >= axes[0]) {
|
||||
s++;
|
||||
} else {
|
||||
ax++;
|
||||
}
|
||||
}
|
||||
return {{expand_dims(inputs[0], std::move(expand_axes), stream())}, {ax}};
|
||||
}
|
||||
|
||||
bool ExpandDims::is_equivalent(const Primitive& other) const {
|
||||
const ExpandDims& a_other = static_cast<const ExpandDims&>(other);
|
||||
return (axes_ == a_other.axes_);
|
||||
}
|
||||
|
||||
Shape ExpandDims::output_shape(
|
||||
const array& input,
|
||||
const std::vector<int>& axes) {
|
||||
auto shape = input.shape();
|
||||
for (auto ax : axes) {
|
||||
shape.insert(shape.begin() + ax, 1);
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
std::vector<Shape> ExpandDims::output_shapes(const std::vector<array>& inputs) {
|
||||
return {ExpandDims::output_shape(inputs[0], axes_)};
|
||||
}
|
||||
|
||||
bool FFT::is_equivalent(const Primitive& other) const {
|
||||
const FFT& r_other = static_cast<const FFT&>(other);
|
||||
return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&
|
||||
@@ -1846,7 +1895,7 @@ bool Gather::is_equivalent(const Primitive& other) const {
|
||||
std::vector<Shape> Gather::output_shapes(const std::vector<array>& inputs) {
|
||||
Shape out_shape;
|
||||
if (inputs.size() > 1) {
|
||||
out_shape = inputs[0].shape();
|
||||
out_shape = inputs[1].shape();
|
||||
}
|
||||
out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end());
|
||||
return {std::move(out_shape)};
|
||||
@@ -3847,6 +3896,57 @@ std::pair<std::vector<array>, std::vector<int>> Subtract::vmap(
|
||||
return {{subtract(a, b, stream())}, {to_ax}};
|
||||
}
|
||||
|
||||
std::vector<array> Squeeze::vjp(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>&) {
|
||||
return {expand_dims(cotangents[0], axes_, stream())};
|
||||
}
|
||||
|
||||
std::vector<array> Squeeze::jvp(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>&) {
|
||||
return {squeeze(tangents[0], axes_, stream())};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Squeeze::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto ax = axes[0];
|
||||
auto squeeze_axes = axes_;
|
||||
for (auto& s : squeeze_axes) {
|
||||
if (s >= axes[0]) {
|
||||
s++;
|
||||
} else {
|
||||
ax--;
|
||||
}
|
||||
}
|
||||
return {{squeeze(inputs[0], std::move(squeeze_axes), stream())}, {ax}};
|
||||
}
|
||||
|
||||
bool Squeeze::is_equivalent(const Primitive& other) const {
|
||||
const Squeeze& a_other = static_cast<const Squeeze&>(other);
|
||||
return (axes_ == a_other.axes_);
|
||||
}
|
||||
|
||||
Shape Squeeze::output_shape(const array& input, const std::vector<int>& axes) {
|
||||
Shape shape;
|
||||
for (int i = 0, j = 0; i < input.ndim(); ++i) {
|
||||
if (j < axes.size() && i == axes[j]) {
|
||||
j++;
|
||||
} else {
|
||||
shape.push_back(input.shape(i));
|
||||
}
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
std::vector<Shape> Squeeze::output_shapes(const std::vector<array>& inputs) {
|
||||
return {Squeeze::output_shape(inputs[0], axes_)};
|
||||
}
|
||||
|
||||
std::vector<array> Tan::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
||||
Reference in New Issue
Block a user