Flatten and unflatten (#1692)

* flatten and unflatten

* fix grad

* fix shape infer

* use squeeze + unsqueeze in get_item
This commit is contained in:
Awni Hannun
2024-12-11 21:51:37 -08:00
committed by GitHub
parent 0bf19037ca
commit 4e1e9520e1
19 changed files with 363 additions and 93 deletions

View File

@@ -405,22 +405,22 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
// Unsqueeze handling
if (unsqueeze_needed || squeeze_needed) {
std::vector<int> out_shape;
int axis = 0;
for (auto& idx : remaining_indices) {
std::vector<int> squeeze_axes;
std::vector<int> unsqueeze_axes;
for (int axis = 0; axis < remaining_indices.size(); ++axis) {
auto& idx = remaining_indices[axis];
if (unsqueeze_needed && idx.is_none()) {
out_shape.push_back(1);
unsqueeze_axes.push_back(axis - squeeze_axes.size());
} else if (squeeze_needed && nb::isinstance<nb::int_>(idx)) {
axis++;
} else {
out_shape.push_back(src.shape(axis++));
squeeze_axes.push_back(axis - unsqueeze_axes.size());
}
}
out_shape.insert(
out_shape.end(), src.shape().begin() + axis, src.shape().end());
src = reshape(src, out_shape);
if (!squeeze_axes.empty()) {
src = squeeze(src, std::move(squeeze_axes));
}
if (!unsqueeze_axes.empty()) {
src = expand_dims(src, std::move(unsqueeze_axes));
}
}
return src;

View File

@@ -103,6 +103,36 @@ void init_ops(nb::module_& m) {
>>> mx.flatten(a, start_axis=0, end_axis=-1)
array([1, 2, 3, 4], dtype=int32)
)pbdoc");
// Register the Python-facing function "unflatten" on module `m`, bound to the
// C++ callable `unflatten`. The argument annotations below describe, in order,
// the parameters exposed to Python.
m.def(
"unflatten",
&unflatten,
// First parameter is left unnamed here; the signature string below calls it
// `a` and places it before `/` (positional-only).
nb::arg(),
"axis"_a,
"shape"_a,
// Everything after this marker is keyword-only (matches the `*` in the
// signature string below).
nb::kw_only(),
// `stream` defaults to Python `None`.
"stream"_a = nb::none(),
// Explicit Python signature text attached to the binding.
// NOTE(review): presumably replaces nanobind's auto-generated signature in
// docs/stubs; keep it in sync with the argument annotations above.
nb::sig(
"def unflatten(a: array, /, axis: int, shape: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
// Docstring shown to Python users for this function.
R"pbdoc(
Unflatten an axis of an array to a shape.
Args:
a (array): Input array.
axis (int): The axis to unflatten.
shape (tuple(int)): The shape to unflatten to. At most one
entry can be ``-1`` in which case the corresponding size will be
inferred.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The unflattened array.
Example:
>>> a = mx.array([1, 2, 3, 4])
>>> mx.unflatten(a, 0, (2, -1))
array([[1, 2], [3, 4]], dtype=int32)
)pbdoc");
m.def(
"squeeze",
[](const mx::array& a, const IntOrVec& v, const mx::StreamOrDevice& s) {