Flatten and unflatten (#1692)

* flatten and unflatten

* fix grad

* fix shape infer

* use squeeze + unsqueeze in get_item
This commit is contained in:
Awni Hannun
2024-12-11 21:51:37 -08:00
committed by GitHub
parent 0bf19037ca
commit 4e1e9520e1
19 changed files with 363 additions and 93 deletions

View File

@@ -103,6 +103,36 @@ void init_ops(nb::module_& m) {
>>> mx.flatten(a, start_axis=0, end_axis=-1)
array([1, 2, 3, 4], dtype=int32)
)pbdoc");
m.def(
"unflatten",
&unflatten,
nb::arg(),
"axis"_a,
"shape"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def unflatten(a: array, /, axis: int, shape: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Unflatten an axis of an array to a shape.
Args:
a (array): Input array.
axis (int): The axis to unflatten.
shape (tuple(int)): The shape to unflatten to. At most one
entry can be ``-1`` in which case the corresponding size will be
inferred.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The unflattened array.
Example:
>>> a = mx.array([1, 2, 3, 4])
>>> mx.unflatten(a, 0, (2, -1))
array([[1, 2], [3, 4]], dtype=int32)
)pbdoc");
m.def(
"squeeze",
[](const mx::array& a, const IntOrVec& v, const mx::StreamOrDevice& s) {