mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Flatten and unflatten (#1692)
* flatten and unflatten * fix grad * fix shape infer * use squeeze + unsqueeze in get_item
This commit is contained in:
@@ -1651,12 +1651,114 @@ std::vector<Shape> ExpandDims::output_shapes(const std::vector<array>& inputs) {
|
||||
return {ExpandDims::output_shape(inputs[0], axes_)};
|
||||
}
|
||||
|
||||
std::vector<array> Flatten::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>&) {
|
||||
auto& in = primals[0];
|
||||
Shape unflatten_shape(
|
||||
in.shape().begin() + start_axis_, in.shape().begin() + end_axis_ + 1);
|
||||
return {unflatten(
|
||||
cotangents[0], start_axis_, std::move(unflatten_shape), stream())};
|
||||
}
|
||||
|
||||
std::vector<array> Flatten::jvp(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>&) {
|
||||
return {flatten(tangents[0], start_axis_, end_axis_, stream())};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Flatten::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto ax = axes[0];
|
||||
auto start_axis = start_axis_;
|
||||
auto end_axis = end_axis_;
|
||||
if (ax < start_axis) {
|
||||
start_axis++;
|
||||
end_axis++;
|
||||
} else {
|
||||
ax -= (end_axis - start_axis);
|
||||
}
|
||||
return {{flatten(inputs[0], start_axis, end_axis, stream())}, {ax}};
|
||||
}
|
||||
|
||||
bool Flatten::is_equivalent(const Primitive& other) const {
|
||||
const Flatten& a_other = static_cast<const Flatten&>(other);
|
||||
return start_axis_ == a_other.start_axis_ && end_axis_ == a_other.end_axis_;
|
||||
}
|
||||
|
||||
// Shape obtained by merging the inclusive axis range [start_axis, end_axis]
// of `input` into a single axis whose extent is the product of the merged
// extents. Axes outside the range are kept as they are.
Shape Flatten::output_shape(const array& input, int start_axis, int end_axis) {
  // Product of the extents over the merged range.
  auto merged = input.shape(start_axis);
  int axis = start_axis + 1;
  while (axis <= end_axis) {
    merged *= input.shape(axis);
    ++axis;
  }

  // Drop every merged axis except the first, which receives the product.
  Shape result = input.shape();
  const auto erase_from = result.begin() + start_axis + 1;
  const auto erase_to = result.begin() + end_axis + 1;
  result.erase(erase_from, erase_to);
  result[start_axis] = merged;
  return result;
}
|
||||
|
||||
std::vector<Shape> Flatten::output_shapes(const std::vector<array>& inputs) {
|
||||
return {Flatten::output_shape(inputs[0], start_axis_, end_axis_)};
|
||||
}
|
||||
|
||||
bool FFT::is_equivalent(const Primitive& other) const {
|
||||
const FFT& r_other = static_cast<const FFT&>(other);
|
||||
return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&
|
||||
real_ == r_other.real_;
|
||||
}
|
||||
|
||||
std::vector<array> Unflatten::vjp(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>&) {
|
||||
return {flatten(cotangents[0], axis_, axis_ + shape_.size(), stream())};
|
||||
}
|
||||
|
||||
std::vector<array> Unflatten::jvp(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>&) {
|
||||
return {unflatten(tangents[0], axis_, shape_, stream())};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Unflatten::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto ax = axes[0];
|
||||
auto axis = axis_;
|
||||
if (ax <= axis_) {
|
||||
axis++;
|
||||
} else {
|
||||
ax += (shape_.size() - 1);
|
||||
}
|
||||
return {{unflatten(inputs[0], axis, shape_, stream())}, {ax}};
|
||||
}
|
||||
|
||||
bool Unflatten::is_equivalent(const Primitive& other) const {
|
||||
const auto& a_other = static_cast<const Unflatten&>(other);
|
||||
return axis_ == a_other.axis_ && shape_ == a_other.shape_;
|
||||
}
|
||||
|
||||
// Shape obtained by replacing axis `axis` of `input` with the extents listed
// in `shape`: the first extent overwrites the existing entry and the
// remaining ones are inserted directly after it.
// NOTE(review): `shape` is read at index 0 unconditionally, so it is assumed
// to be non-empty; confirm callers guarantee this.
Shape Unflatten::output_shape(
    const array& input,
    int axis,
    const Shape& shape) {
  Shape result = input.shape();
  result[axis] = shape[0];
  const auto insert_at = result.begin() + axis + 1;
  const auto rest_begin = shape.begin() + 1;
  result.insert(insert_at, rest_begin, shape.end());
  return result;
}
|
||||
|
||||
std::vector<Shape> Unflatten::output_shapes(const std::vector<array>& inputs) {
|
||||
return {Unflatten::output_shape(inputs[0], axis_, shape_)};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
|
||||
Reference in New Issue
Block a user