diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index 572b02a98..61dbc7596 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -168,6 +168,7 @@ Operations
    tri
    tril
    triu
+   unflatten
    var
    view
    where
diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index c4f3658a7..46246812f 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -66,7 +66,6 @@ DEFAULT(Pad)
 DEFAULT(Partition)
 DEFAULT_MULTI(QRF)
 DEFAULT(RandomBits)
-DEFAULT(Reshape)
 DEFAULT(Remainder)
 DEFAULT(Round)
 DEFAULT(Scatter)
diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp
index b0957b54e..e960e6ec6 100644
--- a/mlx/backend/common/common.cpp
+++ b/mlx/backend/common/common.cpp
@@ -151,9 +151,7 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
-std::pair<bool, Strides> Reshape::prepare_reshape(
-    const array& in,
-    const array& out) {
+std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
   // Special case for empty arrays or row contiguous arrays
   if (in.size() == 0 || in.flags().row_contiguous) {
     return {false, out.strides()};
@@ -190,7 +188,7 @@ std::pair<bool, Strides> Reshape::prepare_reshape(
   return {copy_necessary, out_strides};
 }
 
-void Reshape::shared_buffer_reshape(
+void shared_buffer_reshape(
     const array& in,
     const Strides& out_strides,
     array& out) {
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index a79aca81b..5d359aead 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -87,7 +87,6 @@ DEFAULT_MULTI(QRF)
 DEFAULT(QuantizedMatmul)
 DEFAULT(RandomBits)
 DEFAULT(Reduce)
-DEFAULT(Reshape)
 DEFAULT(Round)
 DEFAULT(Scan)
 DEFAULT(Scatter)
diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp
index 12042ed0f..2f1f95054 100644
--- a/mlx/backend/common/primitives.cpp
+++ b/mlx/backend/common/primitives.cpp
@@ -19,6 +19,16 @@
 
 namespace mlx::core {
 
+void reshape(const array& in, array& out) {
+  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
+  if (copy_necessary) {
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    copy_inplace(in, out, CopyType::General);
+  } else {
+    shared_buffer_reshape(in, out_strides, out);
+  }
+}
+
 void Abs::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
@@ -258,6 +268,14 @@ void Expm1::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
+void Flatten::eval_cpu(const std::vector<array>& inputs, array& out) {
+  reshape(inputs[0], out);
+}
+
+void Unflatten::eval_cpu(const std::vector<array>& inputs, array& out) {
+  reshape(inputs[0], out);
+}
+
 void Floor::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
@@ -417,18 +435,8 @@ void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
   unary_op(inputs[0], out, detail::Real());
 }
 
-void Reshape::eval(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-
-  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
-
-  if (copy_necessary) {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    copy_inplace(in, out, CopyType::General);
-  } else {
-    shared_buffer_reshape(in, out_strides, out);
-  }
+void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
+  reshape(inputs[0], out);
 }
 
 void Round::eval(const std::vector<array>& inputs, array& out) {
diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h
index c67189b5d..99d49f150 100644
--- a/mlx/backend/common/utils.h
+++ b/mlx/backend/common/utils.h
@@ -168,4 +168,10 @@ void move_or_copy(
     size_t data_size,
     size_t offset = 0);
 
+std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
+
+void shared_buffer_reshape(
+    const array& in,
+    const Strides& out_strides,
+    array& out);
 } // namespace mlx::core
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 729001604..012a5217f 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -25,6 +25,25 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
   enc.set_bytes(step, 1);
 }
 
+void reshape(const array& in, array& out, Stream s) {
+  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
+  if (copy_necessary) {
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    copy_gpu_inplace(
+        in,
+        out,
+        in.shape(),
+        in.strides(),
+        make_contiguous_strides(in.shape()),
+        0,
+        0,
+        CopyType::General,
+        s);
+  } else {
+    shared_buffer_reshape(in, out_strides, out);
+  }
+}
+
 void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 0);
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -215,6 +234,14 @@ void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
   eval(inputs, out);
 }
 
+void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
+  reshape(inputs[0], out, stream());
+}
+
+void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
+  reshape(inputs[0], out, stream());
+}
+
 void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
   auto read_task = [out = out,
@@ -309,26 +336,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
 }
 
 void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-
-  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
-
-  if (copy_necessary) {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    copy_gpu_inplace(
-        in,
-        out,
-        in.shape(),
-        in.strides(),
-        make_contiguous_strides(in.shape()),
-        0,
-        0,
-        CopyType::General,
-        stream());
-  } else {
-    shared_buffer_reshape(in, out_strides, out);
-  }
+  reshape(inputs[0], out, stream());
 }
 
 void Split::eval_gpu(
diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp
index dab493345..9db4d3983 100644
--- a/mlx/backend/no_cpu/primitives.cpp
+++ b/mlx/backend/no_cpu/primitives.cpp
@@ -58,6 +58,7 @@ NO_CPU(Exp)
 NO_CPU(ExpandDims)
 NO_CPU(Expm1)
 NO_CPU(FFT)
+NO_CPU(Flatten)
 NO_CPU(Floor)
 NO_CPU(Full)
 NO_CPU(Gather)
@@ -113,6 +114,7 @@ NO_CPU_MULTI(SVD)
 NO_CPU(Tan)
 NO_CPU(Tanh)
 NO_CPU(Transpose)
+NO_CPU(Unflatten)
 NO_CPU(Inverse)
 NO_CPU(View)
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index cfae4ff38..f7a34c8e6 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -58,6 +58,7 @@ NO_GPU(Exp)
 NO_GPU(ExpandDims)
 NO_GPU(Expm1)
 NO_GPU(FFT)
+NO_GPU(Flatten)
 NO_GPU(Floor)
 NO_GPU(Full)
 NO_GPU(Gather)
@@ -113,6 +114,7 @@ NO_GPU_MULTI(SVD)
 NO_GPU(Tan)
 NO_GPU(Tanh)
 NO_GPU(Transpose)
+NO_GPU(Unflatten)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)
 NO_GPU_MULTI(Eigh)
diff --git a/mlx/compile.cpp b/mlx/compile.cpp
index 95ff59865..38a4b52d0 100644
--- a/mlx/compile.cpp
+++ b/mlx/compile.cpp
@@ -82,6 +82,7 @@ bool allows_shapeless(const Primitive& p) {
       typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
       typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) ||
       typeid(p) == typeid(Squeeze) || typeid(p) == typeid(ExpandDims) ||
+      typeid(p) == typeid(Flatten) || typeid(p) == typeid(Unflatten) ||
       typeid(p) == typeid(fast::AffineQuantize) ||
       typeid(p) == typeid(fast::LayerNorm) ||
       typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 53262f1ad..58a0c32a6 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -614,7 +614,7 @@ array scaled_dot_product_attention(
     auto k = inputs[1];
     auto v = inputs[2];
     if (n_repeats > 1) {
-      q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s);
+      q = unflatten(q, 1, {n_kv_heads, n_repeats}, s);
       k = expand_dims(k, 2, s);
       v = expand_dims(v, 2, s);
     }
@@ -629,7 +629,7 @@ array scaled_dot_product_attention(
     scores = softmax(scores, std::vector<int>{-1}, true, s);
     auto out = matmul(scores, v, s);
     if (n_repeats > 1) {
-      out = reshape(out, {B, n_q_heads, L, -1}, s);
+      out = flatten(out, 1, 2, s);
     }
     return std::vector<array>{out};
   };
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index c43aba371..8433b4129 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -267,7 +267,7 @@ array as_strided(
       std::make_shared<AsStrided>(
           to_stream(s), std::move(shape), std::move(strides), offset),
       // Force the input array to be contiguous.
-      {reshape(std::move(a), {-1}, s)});
+      {flatten(std::move(a), s)});
 }
 
 array copy(array a, StreamOrDevice s /* = {} */) {
@@ -380,10 +380,9 @@ array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
 
   // Infer the shape
   if (size > 0) {
-    auto q_and_r = std::ldiv(a.size(), size);
     if (infer_idx >= 0) {
-      shape[infer_idx] = q_and_r.quot;
-      size *= q_and_r.quot;
+      shape[infer_idx] = a.size() / size;
+      size *= shape[infer_idx];
     }
   } else if (infer_idx >= 0) {
     throw std::invalid_argument(
@@ -401,6 +400,59 @@ array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
   return array(std::move(shape), a.dtype(), std::move(p), {a});
 }
 
+array unflatten(
+    const array& a,
+    int axis,
+    Shape shape,
+    StreamOrDevice s /* = {} */) {
+  if (shape.empty()) {
+    throw std::invalid_argument(
+        "[unflatten] Shape to unflatten to cannot be empty.");
+  }
+  auto ndim = static_cast<int>(a.ndim());
+  auto ax = axis < 0 ? axis + ndim : axis;
+  if (ax < 0 || ax >= ndim) {
+    std::ostringstream msg;
+    msg << "[unflatten] Invalid axes " << ax << " for array with " << a.ndim()
+        << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  size_t size = 1;
+  int infer_idx = -1;
+  for (int i = 0; i < shape.size(); ++i) {
+    if (shape[i] == -1) {
+      if (infer_idx >= 0) {
+        throw std::invalid_argument(
+            "[Unflatten] Can only infer one dimension.");
+      }
+      infer_idx = i;
+    } else {
+      size *= shape[i];
+    }
+  }
+  if (infer_idx >= 0) {
+    shape[infer_idx] = a.shape(ax) / size;
+    size *= shape[infer_idx];
+  }
+  if (size != a.shape(ax)) {
+    std::ostringstream msg;
+    msg << "[Unflatten] Cannot unflatten axis " << axis << " with size "
+        << a.shape(ax) << " into shape " << shape << ".";
+    throw std::invalid_argument(msg.str());
+  }
+  if (shape.size() == 1) {
+    return a;
+  }
+
+  auto out_shape = Unflatten::output_shape(a, ax, shape);
+  return array(
+      std::move(out_shape),
+      a.dtype(),
+      std::make_shared<Unflatten>(to_stream(s), ax, std::move(shape)),
+      {a});
+}
+
 array flatten(
     const array& a,
     int start_axis,
@@ -433,11 +485,11 @@ array flatten(
   if (start_ax == end_ax) {
     return a;
   }
-  Shape new_shape(a.shape().begin(), a.shape().begin() + start_ax);
-  new_shape.push_back(-1);
-  new_shape.insert(
-      new_shape.end(), a.shape().begin() + end_ax + 1, a.shape().end());
-  return reshape(a, std::move(new_shape), s);
+  return array(
+      Flatten::output_shape(a, start_ax, end_ax),
+      a.dtype(),
+      std::make_shared<Flatten>(to_stream(s), start_ax, end_ax),
+      {a});
 }
 
 array flatten(const array& a, StreamOrDevice s /* = {} */) {
@@ -901,7 +953,7 @@ array concatenate(
     StreamOrDevice s /* = {} */) {
   std::vector<array> flat_inputs;
   for (auto& a : arrays) {
-    flat_inputs.push_back(reshape(a, {-1}, s));
+    flat_inputs.push_back(flatten(a, s));
   }
   return concatenate(flat_inputs, 0, s);
 }
@@ -2568,22 +2620,9 @@ array matmul(
   }
 
   // We can batch the multiplication by reshaping a
-  if (a.ndim() > 2 && b.ndim() == 2) {
-    std::vector<int> out_shape = a.shape();
-    a = reshape(a, {-1, out_shape.back()}, s);
-    out_shape.back() = b.shape(-1);
-    if (in_b.ndim() == 1) {
-      out_shape.pop_back();
-    }
-    auto out = array(
-        {a.shape(0), b.shape(1)},
-        out_type,
-        std::make_shared<Matmul>(to_stream(s)),
-        {a, b});
-    return reshape(out, out_shape, s);
-  }
-
-  if (a.ndim() > 2 || b.ndim() > 2) {
+  if (in_a.ndim() > 2 && in_b.ndim() <= 2) {
+    a = flatten(a, 0, -2, s);
+  } else if (in_b.ndim() > 2) {
     Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
     Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
     auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
@@ -2607,6 +2646,11 @@ array matmul(
       out_type,
       std::make_shared<Matmul>(to_stream(s)),
       {a, b});
+  if (in_a.ndim() > 2 && in_b.ndim() <= 2) {
+    auto orig_shape = in_a.shape();
+    orig_shape.pop_back();
+    out = unflatten(out, 0, std::move(orig_shape), s);
+  }
 
   // Remove the possibly inserted singleton dimensions
   std::vector<int> axes;
@@ -2753,7 +2797,7 @@ array take(
 }
 
 array take(const array& a, const array& indices, StreamOrDevice s /* = {} */) {
-  return take(reshape(a, {-1}, s), indices, 0, s);
+  return take(flatten(a, s), indices, 0, s);
 }
 
 array take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) {
@@ -2783,7 +2827,7 @@ array take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) {
 }
 
 array take(const array& a, int index, StreamOrDevice s /* = {} */) {
-  return take(reshape(a, {-1}, s), index, 0, s);
+  return take(flatten(a, s), index, 0, s);
 }
 
 array take_along_axis(
@@ -3853,11 +3897,11 @@ array addmm(
 
   if (a.ndim() == 1) {
     // Insert a singleton dim in the beginning
-    a = reshape(a, {1, -1}, s);
+    a = expand_dims(a, 0, s);
  }
   if (b.ndim() == 1) {
     // Insert a singleton dim at the end
-    b = reshape(b, {-1, 1}, s);
+    b = expand_dims(b, 1, s);
   }
 
   if (a.shape(-1) != b.shape(-2)) {
@@ -4644,7 +4688,7 @@ array roll(
 array roll(const array& a, int shift, StreamOrDevice s /* = {} */) {
   auto shape = a.shape();
   return reshape(
-      roll(reshape(a, Shape{-1}, s), Shape{shift}, std::vector<int>{0}, s),
+      roll(flatten(a, s), Shape{shift}, std::vector<int>{0}, s),
       std::move(shape),
       s);
 }
diff --git a/mlx/ops.h b/mlx/ops.h
index 7e24b5820..7576774b5 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -117,6 +117,9 @@ array triu(array x, int k = 0, StreamOrDevice s = {});
 /** Reshape an array to the given shape. */
 array reshape(const array& a, Shape shape, StreamOrDevice s = {});
 
+/** Unflatten the axis to the given shape. */
+array unflatten(const array& a, int axis, Shape shape, StreamOrDevice s = {});
+
 /** Flatten the dimensions in the range `[start_axis, end_axis]` . */
 array flatten(
     const array& a,
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index e139f7385..fa7b384f5 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -1651,12 +1651,114 @@ std::vector<Shape> ExpandDims::output_shapes(const std::vector<array>& inputs) {
   return {ExpandDims::output_shape(inputs[0], axes_)};
 }
 
+std::vector<array> Flatten::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>&,
+    const std::vector<array>&) {
+  auto& in = primals[0];
+  Shape unflatten_shape(
+      in.shape().begin() + start_axis_, in.shape().begin() + end_axis_ + 1);
+  return {unflatten(
+      cotangents[0], start_axis_, std::move(unflatten_shape), stream())};
+}
+
+std::vector<array> Flatten::jvp(
+    const std::vector<array>&,
+    const std::vector<array>& tangents,
+    const std::vector<int>&) {
+  return {flatten(tangents[0], start_axis_, end_axis_, stream())};
+}
+
+std::pair<std::vector<array>, std::vector<int>> Flatten::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto ax = axes[0];
+  auto start_axis = start_axis_;
+  auto end_axis = end_axis_;
+  if (ax < start_axis) {
+    start_axis++;
+    end_axis++;
+  } else {
+    ax -= (end_axis - start_axis);
+  }
+  return {{flatten(inputs[0], start_axis, end_axis, stream())}, {ax}};
+}
+
+bool Flatten::is_equivalent(const Primitive& other) const {
+  const Flatten& a_other = static_cast<const Flatten&>(other);
+  return start_axis_ == a_other.start_axis_ && end_axis_ == a_other.end_axis_;
+}
+
+Shape Flatten::output_shape(const array& input, int start_axis, int end_axis) {
+  Shape shape = input.shape();
+  auto flat_size = input.shape(start_axis);
+  for (int ax = start_axis + 1; ax <= end_axis; ++ax) {
+    flat_size *= input.shape(ax);
+  }
+  shape.erase(shape.begin() + start_axis + 1, shape.begin() + end_axis + 1);
+  shape[start_axis] = flat_size;
+  return shape;
+}
+
+std::vector<Shape> Flatten::output_shapes(const std::vector<array>& inputs) {
+  return {Flatten::output_shape(inputs[0], start_axis_, end_axis_)};
+}
+
 bool FFT::is_equivalent(const Primitive& other) const {
   const FFT& r_other = static_cast<const FFT&>(other);
   return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&
       real_ == r_other.real_;
 }
 
+std::vector<array> Unflatten::vjp(
+    const std::vector<array>&,
+    const std::vector<array>& cotangents,
+    const std::vector<int>&,
+    const std::vector<array>&) {
+  return {flatten(cotangents[0], axis_, axis_ + shape_.size() - 1, stream())};
+}
+
+std::vector<array> Unflatten::jvp(
+    const std::vector<array>&,
+    const std::vector<array>& tangents,
+    const std::vector<int>&) {
+  return {unflatten(tangents[0], axis_, shape_, stream())};
+}
+
+std::pair<std::vector<array>, std::vector<int>> Unflatten::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto ax = axes[0];
+  auto axis = axis_;
+  if (ax <= axis_) {
+    axis++;
+  } else {
+    ax += (shape_.size() - 1);
+  }
+  return {{unflatten(inputs[0], axis, shape_, stream())}, {ax}};
+}
+
+bool Unflatten::is_equivalent(const Primitive& other) const {
+  const auto& a_other = static_cast<const Unflatten&>(other);
+  return axis_ == a_other.axis_ && shape_ == a_other.shape_;
+}
+
+Shape Unflatten::output_shape(
+    const array& input,
+    int axis,
+    const Shape& shape) {
+  Shape out_shape = input.shape();
+  out_shape[axis] = shape[0];
+  out_shape.insert(
+      out_shape.begin() + axis + 1, shape.begin() + 1, shape.end());
+  return out_shape;
+}
+
+std::vector<Shape> Unflatten::output_shapes(const std::vector<array>& inputs) {
+  return {Unflatten::output_shape(inputs[0], axis_, shape_)};
+}
+
 std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 1ca913a37..55a87cf18 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1031,6 +1031,28 @@ class FFT : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class Flatten : public UnaryPrimitive {
+ public:
+  explicit Flatten(Stream stream, int start_axis, int end_axis)
+      : UnaryPrimitive(stream), start_axis_(start_axis), end_axis_(end_axis) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(Flatten)
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
+  bool is_equivalent(const Primitive& other) const override;
+
+  static Shape output_shape(const array& input, int start_axis, int end_axis);
+
+ private:
+  int start_axis_;
+  int end_axis_;
+  void eval(const std::vector<array>& inputs, array& out);
+};
+
 class Floor : public UnaryPrimitive {
  public:
   explicit Floor(Stream stream) : UnaryPrimitive(stream) {}
@@ -1643,16 +1665,6 @@ class Reshape : public UnaryPrimitive {
 
  private:
   Shape shape_;
-
-  void eval(const std::vector<array>& inputs, array& out);
-
-  static std::pair<bool, Strides> prepare_reshape(
-      const array& in,
-      const array& out);
-  static void shared_buffer_reshape(
-      const array& in,
-      const Strides& out_strides,
-      array& out);
 };
 
 class Reduce : public UnaryPrimitive {
@@ -2137,6 +2149,28 @@ class Tanh : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class Unflatten : public UnaryPrimitive {
+ public:
+  explicit Unflatten(Stream stream, int axis, Shape shape)
+      : UnaryPrimitive(stream), axis_(axis), shape_(std::move(shape)) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(Unflatten)
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
+  bool is_equivalent(const Primitive& other) const override;
+
+  static Shape output_shape(const array& input, int axis, const Shape& shape);
+
+ private:
+  int axis_;
+  Shape shape_;
+  void eval(const std::vector<array>& inputs, array& out);
+};
+
 class Uniform : public UnaryPrimitive {
  public:
  explicit Uniform(Stream stream) : UnaryPrimitive(stream) {}
diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp
index 6261f2603..6278f5b99 100644
--- a/python/src/indexing.cpp
+++ b/python/src/indexing.cpp
@@ -405,22 +405,22 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
 
   // Unsqueeze handling
   if (unsqueeze_needed || squeeze_needed) {
-    std::vector<int> out_shape;
-    int axis = 0;
-    for (auto& idx : remaining_indices) {
+    std::vector<int> squeeze_axes;
+    std::vector<int> unsqueeze_axes;
+    for (int axis = 0; axis < remaining_indices.size(); ++axis) {
+      auto& idx = remaining_indices[axis];
       if (unsqueeze_needed && idx.is_none()) {
-        out_shape.push_back(1);
+        unsqueeze_axes.push_back(axis - squeeze_axes.size());
       } else if (squeeze_needed && nb::isinstance<nb::int_>(idx)) {
-        axis++;
-      } else {
-        out_shape.push_back(src.shape(axis++));
+        squeeze_axes.push_back(axis - unsqueeze_axes.size());
       }
     }
-
-    out_shape.insert(
-        out_shape.end(), src.shape().begin() + axis, src.shape().end());
-
-    src = reshape(src, out_shape);
+    if (!squeeze_axes.empty()) {
+      src = squeeze(src, std::move(squeeze_axes));
+    }
+    if (!unsqueeze_axes.empty()) {
+      src = expand_dims(src, std::move(unsqueeze_axes));
+    }
   }
 
   return src;
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index d268a865e..bd64bf687 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -103,6 +103,36 @@ void init_ops(nb::module_& m) {
         >>> mx.flatten(a, start_axis=0, end_axis=-1)
         array([1, 2, 3, 4], dtype=int32)
   )pbdoc");
+  m.def(
+      "unflatten",
+      &unflatten,
+      nb::arg(),
+      "axis"_a,
+      "shape"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def unflatten(a: array, /, axis: int, shape: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Unflatten an axis of an array to a shape.
+
+        Args:
+            a (array): Input array.
+            axis (int): The axis to unflatten.
+            shape (tuple(int)): The shape to unflatten to. At most one
+              entry can be ``-1`` in which case the corresponding size will be
+              inferred.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            array: The unflattened array.
+
+        Example:
+            >>> a = mx.array([1, 2, 3, 4])
+            >>> mx.unflatten(a, 0, (2, -1))
+            array([[1, 2], [3, 4]], dtype=int32)
+  )pbdoc");
   m.def(
       "squeeze",
       [](const mx::array& a, const IntOrVec& v, const mx::StreamOrDevice& s) {
diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py
index c0d5e3647..42646bfe1 100644
--- a/python/tests/test_compile.py
+++ b/python/tests/test_compile.py
@@ -462,6 +462,22 @@ class TestCompile(mlx_tests.MLXTestCase):
         cfun = mx.compile(fun, shapeless=True)
         self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
 
+    def test_shapeless_compile_unflatten(self):
+        x = mx.zeros((1, 1, 4 * 32))
+
+        def fun(x):
+            return mx.unflatten(x, -1, (4, -1))
+
+        self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 1, 4, 32))
+
+    def test_shapeless_compile_gather(self):
+        x = mx.zeros((1, 1, 32))
+
+        def fun(x):
+            return x[:, -1, :]
+
+        self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 32))
+
     def test_compile_with_constant(self):
         # Test float
         @partial(mx.compile)
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index ae830262b..6842717cc 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -163,6 +163,23 @@ TEST_CASE("test flatten") {
   CHECK_EQ(flatten(x, 0, 0).shape(), Shape({1}));
 }
 
+TEST_CASE("test unflatten") {
+  array x = array(1);
+  CHECK_THROWS(unflatten(x, 0, {1, 1}));
+
+  x = array({1});
+  auto out = unflatten(x, 0, {1, 1});
+  CHECK_EQ(out.shape(), Shape({1, 1}));
+  CHECK_THROWS(unflatten(x, 1, {1, 1}));
+  CHECK_THROWS(unflatten(x, 0, {-1, -1}));
+  CHECK_THROWS(unflatten(x, 0, {-1, 2}));
+  CHECK_THROWS(unflatten(x, 0, {}));
+
+  x = zeros({4, 8});
+  out = unflatten(x, 1, {2, 2, 2});
+  CHECK_EQ(out.shape(), Shape({4, 2, 2, 2}));
+}
+
 TEST_CASE("test squeeze and expand") {
   array x = zeros({2, 1, 2, 1, 2, 1});
   CHECK_EQ(squeeze(x).shape(), Shape{2, 2, 2});
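For reference, a small Python sketch of the behavior this patch exposes. It is illustrative only and is not part of the patch; it simply restates the docstring example above and the shapes exercised in tests/ops_tests.cpp:

    import mlx.core as mx

    a = mx.array([1, 2, 3, 4])
    b = mx.unflatten(a, 0, (2, -1))  # the -1 entry is inferred, giving shape (2, 2)
    # b == array([[1, 2], [3, 4]], dtype=int32)

    x = mx.zeros((4, 8))
    y = mx.unflatten(x, 1, (2, 2, 2))  # (4, 8) -> (4, 2, 2, 2)
    z = mx.flatten(y, 1, 3)            # flatten over axes 1..3 inverts it: back to (4, 8)
    print(y.shape, z.shape)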