ExpandDims primitive (#1687)

* add squeeze primitive

* simplify squeeze, use in gather

* fix

* fix

* fix

* fix

* fix no cpu

* use squeeze in matmul and friends

* expand dims primitive

* comment
This commit is contained in:
Awni Hannun
2024-12-10 16:39:07 -08:00
committed by GitHub
parent 310ad8d9db
commit f76a49e555
13 changed files with 373 additions and 184 deletions

View File

@@ -1602,6 +1602,55 @@ std::pair<std::vector<array>, std::vector<int>> Expm1::vmap(
return {{expm1(inputs[0], stream())}, axes};
}
std::vector<array> ExpandDims::vjp(
const std::vector<array>&,
const std::vector<array>& cotangents,
const std::vector<int>&,
const std::vector<array>&) {
return {squeeze(cotangents[0], axes_, stream())};
}
std::vector<array> ExpandDims::jvp(
const std::vector<array>&,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {expand_dims(tangents[0], axes_, stream())};
}
// Vectorized (batched) version of ExpandDims.
//
// inputs[0] is the batched input and axes[0] is the position of its batch
// dimension. Returns the expanded batched array together with the position of
// the batch dimension in that output.
//
// axes_ index positions in the *output* of the unbatched op, so each one has
// to be compared against where the batch dimension currently sits in the
// output being built (the running `ax`), not against its original input
// position. Comparing against axes[0] mislabels the batch axis when several
// expand axes precede it: e.g. a batched input of shape (3, B) with batch
// axis 1 and axes_ = {0, 1} must yield (1, 1, 3, B) with batch axis 3,
// whereas using axes[0] yields (1, 3, 1, B) and reports axis 2.
//
// NOTE(review): assumes axes_ is sorted ascending, as the sequential inserts
// in ExpandDims::output_shape also require — confirm against the caller.
std::pair<std::vector<array>, std::vector<int>> ExpandDims::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  auto ax = axes[0];
  auto expand_axes = axes_;
  for (auto& s : expand_axes) {
    if (s >= ax) {
      // The new axis lands at or after the batch dimension: shift it right by
      // one to make room for the batch dimension. Since the axes are visited
      // in ascending order, `ax` no longer changes after this point.
      s++;
    } else {
      // The new axis is inserted before the batch dimension, which pushes the
      // batch dimension one position to the right.
      ax++;
    }
  }
  return {{expand_dims(inputs[0], std::move(expand_axes), stream())}, {ax}};
}
bool ExpandDims::is_equivalent(const Primitive& other) const {
const ExpandDims& a_other = static_cast<const ExpandDims&>(other);
return (axes_ == a_other.axes_);
}
// Compute the shape produced by expanding `input` along `axes`.
// Starting from a copy of the input shape, a dimension of size 1 is inserted
// at each listed position, one axis at a time and in the order given, so each
// position is interpreted relative to the shape after the previous inserts.
// No bounds checking is performed on the axes here.
Shape ExpandDims::output_shape(
    const array& input,
    const std::vector<int>& axes) {
  auto expanded = input.shape();
  for (size_t i = 0; i < axes.size(); ++i) {
    const auto where = expanded.begin() + axes[i];
    expanded.insert(where, 1);
  }
  return expanded;
}
std::vector<Shape> ExpandDims::output_shapes(const std::vector<array>& inputs) {
return {ExpandDims::output_shape(inputs[0], axes_)};
}
bool FFT::is_equivalent(const Primitive& other) const {
const FFT& r_other = static_cast<const FFT&>(other);
return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&
@@ -1846,7 +1895,7 @@ bool Gather::is_equivalent(const Primitive& other) const {
// Shape inference for Gather: a single output shape.
// When more than one input is present, the output shape starts as the shape
// of inputs[1] (presumably the first index array — verify against the Gather
// constructor); otherwise it starts empty. slice_sizes_ is then appended.
std::vector<Shape> Gather::output_shapes(const std::vector<array>& inputs) {
  Shape out_shape;
  if (inputs.size() > 1) {
    // Only inputs[1] contributes here; a preceding copy of inputs[0].shape()
    // was a dead store that this assignment overwrote immediately.
    out_shape = inputs[1].shape();
  }
  out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end());
  return {std::move(out_shape)};
}
@@ -3847,6 +3896,57 @@ std::pair<std::vector<array>, std::vector<int>> Subtract::vmap(
return {{subtract(a, b, stream())}, {to_ax}};
}
std::vector<array> Squeeze::vjp(
const std::vector<array>&,
const std::vector<array>& cotangents,
const std::vector<int>&,
const std::vector<array>&) {
return {expand_dims(cotangents[0], axes_, stream())};
}
std::vector<array> Squeeze::jvp(
const std::vector<array>&,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {squeeze(tangents[0], axes_, stream())};
}
// Vectorized (batched) version of Squeeze.
//
// inputs[0] is the batched input and axes[0] is the position of its batch
// dimension. Every entry of axes_ is compared against that original batch
// position: entries at or after it are shifted right by one to skip over the
// batch dimension, while each entry before it moves the output batch
// dimension one position to the left. Returns the squeezed batched array and
// the resulting batch axis.
std::pair<std::vector<array>, std::vector<int>> Squeeze::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  const auto batch_axis = axes[0];
  auto out_axis = batch_axis;
  auto shifted_axes = axes_;
  for (auto& axis : shifted_axes) {
    if (axis < batch_axis) {
      // A squeezed axis in front of the batch dimension pulls it left.
      --out_axis;
    } else {
      // Account for the batch dimension sitting before this axis.
      ++axis;
    }
  }
  return {
      {squeeze(inputs[0], std::move(shifted_axes), stream())}, {out_axis}};
}
bool Squeeze::is_equivalent(const Primitive& other) const {
const Squeeze& a_other = static_cast<const Squeeze&>(other);
return (axes_ == a_other.axes_);
}
// Compute the shape left after removing the dimensions listed in `axes` from
// `input`. The input dimensions are walked in order alongside a cursor into
// `axes`: a dimension is dropped only when it equals the axis under the
// cursor, which then advances; every other dimension is kept. Because the
// cursor only moves forward, axes are matched in the order they are listed.
// The sizes of the dropped dimensions are not checked here.
Shape Squeeze::output_shape(const array& input, const std::vector<int>& axes) {
  Shape kept;
  auto next_axis = axes.begin();
  for (int dim = 0; dim < input.ndim(); ++dim) {
    if (next_axis != axes.end() && *next_axis == dim) {
      ++next_axis;
      continue;
    }
    kept.push_back(input.shape(dim));
  }
  return kept;
}
std::vector<Shape> Squeeze::output_shapes(const std::vector<array>& inputs) {
return {Squeeze::output_shape(inputs[0], axes_)};
}
std::vector<array> Tan::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,