Flatten and unflatten (#1692)

* flatten and unflatten

* fix grad

* fix shape infer

* use squeeze + unsqueeze in get_item
This commit is contained in:
Awni Hannun
2024-12-11 21:51:37 -08:00
committed by GitHub
parent 0bf19037ca
commit 4e1e9520e1
19 changed files with 363 additions and 93 deletions

View File

@@ -405,22 +405,22 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
// Unsqueeze handling
if (unsqueeze_needed || squeeze_needed) {
std::vector<int> out_shape;
int axis = 0;
for (auto& idx : remaining_indices) {
std::vector<int> squeeze_axes;
std::vector<int> unsqueeze_axes;
for (int axis = 0; axis < remaining_indices.size(); ++axis) {
auto& idx = remaining_indices[axis];
if (unsqueeze_needed && idx.is_none()) {
out_shape.push_back(1);
unsqueeze_axes.push_back(axis - squeeze_axes.size());
} else if (squeeze_needed && nb::isinstance<nb::int_>(idx)) {
axis++;
} else {
out_shape.push_back(src.shape(axis++));
squeeze_axes.push_back(axis - unsqueeze_axes.size());
}
}
out_shape.insert(
out_shape.end(), src.shape().begin() + axis, src.shape().end());
src = reshape(src, out_shape);
if (!squeeze_axes.empty()) {
src = squeeze(src, std::move(squeeze_axes));
}
if (!unsqueeze_axes.empty()) {
src = expand_dims(src, std::move(unsqueeze_axes));
}
}
return src;