Flatten and unflatten (#1692)

* flatten and unflatten

* fix grad

* fix shape infer

* use squeeze + unsqueeze in get_item
This commit is contained in:
Awni Hannun
2024-12-11 21:51:37 -08:00
committed by GitHub
parent 0bf19037ca
commit 4e1e9520e1
19 changed files with 363 additions and 93 deletions

View File

@@ -1651,12 +1651,114 @@ std::vector<Shape> ExpandDims::output_shapes(const std::vector<array>& inputs) {
return {ExpandDims::output_shape(inputs[0], axes_)};
}
std::vector<array> Flatten::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>&,
const std::vector<array>&) {
auto& in = primals[0];
Shape unflatten_shape(
in.shape().begin() + start_axis_, in.shape().begin() + end_axis_ + 1);
return {unflatten(
cotangents[0], start_axis_, std::move(unflatten_shape), stream())};
}
std::vector<array> Flatten::jvp(
const std::vector<array>&,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {flatten(tangents[0], start_axis_, end_axis_, stream())};
}
std::pair<std::vector<array>, std::vector<int>> Flatten::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0];
auto start_axis = start_axis_;
auto end_axis = end_axis_;
if (ax < start_axis) {
start_axis++;
end_axis++;
} else {
ax -= (end_axis - start_axis);
}
return {{flatten(inputs[0], start_axis, end_axis, stream())}, {ax}};
}
bool Flatten::is_equivalent(const Primitive& other) const {
const Flatten& a_other = static_cast<const Flatten&>(other);
return start_axis_ == a_other.start_axis_ && end_axis_ == a_other.end_axis_;
}
Shape Flatten::output_shape(const array& input, int start_axis, int end_axis) {
Shape shape = input.shape();
auto flat_size = input.shape(start_axis);
for (int ax = start_axis + 1; ax <= end_axis; ++ax) {
flat_size *= input.shape(ax);
}
shape.erase(shape.begin() + start_axis + 1, shape.begin() + end_axis + 1);
shape[start_axis] = flat_size;
return shape;
}
std::vector<Shape> Flatten::output_shapes(const std::vector<array>& inputs) {
return {Flatten::output_shape(inputs[0], start_axis_, end_axis_)};
}
bool FFT::is_equivalent(const Primitive& other) const {
const FFT& r_other = static_cast<const FFT&>(other);
return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&
real_ == r_other.real_;
}
std::vector<array> Unflatten::vjp(
const std::vector<array>&,
const std::vector<array>& cotangents,
const std::vector<int>&,
const std::vector<array>&) {
return {flatten(cotangents[0], axis_, axis_ + shape_.size(), stream())};
}
std::vector<array> Unflatten::jvp(
const std::vector<array>&,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {unflatten(tangents[0], axis_, shape_, stream())};
}
std::pair<std::vector<array>, std::vector<int>> Unflatten::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0];
auto axis = axis_;
if (ax <= axis_) {
axis++;
} else {
ax += (shape_.size() - 1);
}
return {{unflatten(inputs[0], axis, shape_, stream())}, {ax}};
}
bool Unflatten::is_equivalent(const Primitive& other) const {
const auto& a_other = static_cast<const Unflatten&>(other);
return axis_ == a_other.axis_ && shape_ == a_other.shape_;
}
Shape Unflatten::output_shape(
const array& input,
int axis,
const Shape& shape) {
Shape out_shape = input.shape();
out_shape[axis] = shape[0];
out_shape.insert(
out_shape.begin() + axis + 1, shape.begin() + 1, shape.end());
return out_shape;
}
std::vector<Shape> Unflatten::output_shapes(const std::vector<array>& inputs) {
return {Unflatten::output_shape(inputs[0], axis_, shape_)};
}
std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {