mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Flatten and unflatten (#1692)
* flatten and unflatten * fix grad * fix shape infer * use squeeze + unsqueeze in get_item
This commit is contained in:
@@ -117,6 +117,9 @@ array triu(array x, int k = 0, StreamOrDevice s = {});
|
||||
/** Reshape an array to the given shape. */
|
||||
array reshape(const array& a, Shape shape, StreamOrDevice s = {});
|
||||
|
||||
/** Unflatten the axis to the given shape. */
|
||||
array unflatten(const array& a, int axis, Shape shape, StreamOrDevice s = {});
|
||||
|
||||
/** Flatten the dimensions in the range `[start_axis, end_axis]` . */
|
||||
array flatten(
|
||||
const array& a,
|
||||
|
||||
Reference in New Issue
Block a user