Flatten and unflatten (#1692)

* flatten and unflatten

* fix grad

* fix shape infer

* use squeeze + unsqueeze in get_item
Awni Hannun 2024-12-11 21:51:37 -08:00 committed by GitHub
parent 0bf19037ca
commit 4e1e9520e1
19 changed files with 363 additions and 93 deletions
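At a glance: `flatten` becomes a first-class primitive and gains an inverse, `unflatten`. A minimal Python sketch of the new surface (shapes follow the docs and tests below):

import mlx.core as mx

x = mx.zeros((2, 3, 4))
print(mx.flatten(x).shape)                # (24,): all axes by default
print(mx.flatten(x, 1, 2).shape)          # (2, 12): only axes 1 through 2
y = mx.zeros((2, 12))
print(mx.unflatten(y, 1, (3, -1)).shape)  # (2, 3, 4): the -1 is inferred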

View File

@@ -168,6 +168,7 @@ Operations
tri
tril
triu
unflatten
var
view
where

View File

@@ -66,7 +66,6 @@ DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT_MULTI(QRF)
DEFAULT(RandomBits)
DEFAULT(Reshape)
DEFAULT(Remainder)
DEFAULT(Round)
DEFAULT(Scatter)

View File

@@ -151,9 +151,7 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
}
}
std::pair<bool, Strides> Reshape::prepare_reshape(
const array& in,
const array& out) {
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
// Special case for empty arrays or row contiguous arrays
if (in.size() == 0 || in.flags().row_contiguous) {
return {false, out.strides()};
@@ -190,7 +188,7 @@ std::pair<bool, Strides> Reshape::prepare_reshape(
return {copy_necessary, out_strides};
}
void Reshape::shared_buffer_reshape(
void shared_buffer_reshape(
const array& in,
const Strides& out_strides,
array& out) {

View File

@@ -87,7 +87,6 @@ DEFAULT_MULTI(QRF)
DEFAULT(QuantizedMatmul)
DEFAULT(RandomBits)
DEFAULT(Reduce)
DEFAULT(Reshape)
DEFAULT(Round)
DEFAULT(Scan)
DEFAULT(Scatter)

View File

@@ -19,6 +19,16 @@
namespace mlx::core {
void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
void Abs::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -258,6 +268,14 @@ void Expm1::eval(const std::vector<array>& inputs, array& out) {
}
}
void Flatten::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Unflatten::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Floor::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -417,18 +435,8 @@ void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Round::eval(const std::vector<array>& inputs, array& out) {

View File

@@ -168,4 +168,10 @@ void move_or_copy(
size_t data_size,
size_t offset = 0);
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
void shared_buffer_reshape(
const array& in,
const Strides& out_strides,
array& out);
} // namespace mlx::core

View File

@@ -25,6 +25,25 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
enc.set_bytes(step, 1);
}
void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -215,6 +234,14 @@ void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out, stream());
}
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto read_task = [out = out,
@@ -309,26 +336,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
}
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
stream());
} else {
shared_buffer_reshape(in, out_strides, out);
}
reshape(inputs[0], out, stream());
}
void Split::eval_gpu(

View File

@@ -58,6 +58,7 @@ NO_CPU(Exp)
NO_CPU(ExpandDims)
NO_CPU(Expm1)
NO_CPU(FFT)
NO_CPU(Flatten)
NO_CPU(Floor)
NO_CPU(Full)
NO_CPU(Gather)
@@ -113,6 +114,7 @@ NO_CPU_MULTI(SVD)
NO_CPU(Tan)
NO_CPU(Tanh)
NO_CPU(Transpose)
NO_CPU(Unflatten)
NO_CPU(Inverse)
NO_CPU(View)

View File

@@ -58,6 +58,7 @@ NO_GPU(Exp)
NO_GPU(ExpandDims)
NO_GPU(Expm1)
NO_GPU(FFT)
NO_GPU(Flatten)
NO_GPU(Floor)
NO_GPU(Full)
NO_GPU(Gather)
@@ -113,6 +114,7 @@ NO_GPU_MULTI(SVD)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Transpose)
NO_GPU(Unflatten)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eigh)

View File

@@ -82,6 +82,7 @@ bool allows_shapeless(const Primitive& p) {
typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) ||
typeid(p) == typeid(Squeeze) || typeid(p) == typeid(ExpandDims) ||
typeid(p) == typeid(Flatten) || typeid(p) == typeid(Unflatten) ||
typeid(p) == typeid(fast::AffineQuantize) ||
typeid(p) == typeid(fast::LayerNorm) ||
typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||

View File

@@ -614,7 +614,7 @@ array scaled_dot_product_attention(
auto k = inputs[1];
auto v = inputs[2];
if (n_repeats > 1) {
q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s);
q = unflatten(q, 1, {n_kv_heads, n_repeats}, s);
k = expand_dims(k, 2, s);
v = expand_dims(v, 2, s);
}
@@ -629,7 +629,7 @@
scores = softmax(scores, std::vector<int>{-1}, true, s);
auto out = matmul(scores, v, s);
if (n_repeats > 1) {
out = reshape(out, {B, n_q_heads, L, -1}, s);
out = flatten(out, 1, 2, s);
}
return std::vector<array>{out};
};
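Why the swap: `unflatten`/`flatten` express the head split without baking `B` or `L` into the graph, which plays well with shapeless compilation. A sketch of the equivalence with made-up sizes:

import mlx.core as mx

B, n_kv_heads, n_repeats, L, D = 2, 4, 3, 5, 16
q = mx.zeros((B, n_kv_heads * n_repeats, L, D))
old = mx.reshape(q, (B, n_kv_heads, n_repeats, L, -1))  # previous form
new = mx.unflatten(q, 1, (n_kv_heads, n_repeats))       # new form
print(old.shape == new.shape)                  # True
print(mx.flatten(new, 1, 2).shape == q.shape)  # True: flatten(1, 2) undoes it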

View File

@@ -267,7 +267,7 @@ array as_strided(
std::make_shared<AsStrided>(
to_stream(s), std::move(shape), std::move(strides), offset),
// Force the input array to be contiguous.
{reshape(std::move(a), {-1}, s)});
{flatten(std::move(a), s)});
}
array copy(array a, StreamOrDevice s /* = {} */) {
@@ -380,10 +380,9 @@ array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
// Infer the shape
if (size > 0) {
auto q_and_r = std::ldiv(a.size(), size);
if (infer_idx >= 0) {
shape[infer_idx] = q_and_r.quot;
size *= q_and_r.quot;
shape[infer_idx] = a.size() / size;
size *= shape[infer_idx];
}
} else if (infer_idx >= 0) {
throw std::invalid_argument(
@@ -401,6 +400,59 @@ array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
return array(std::move(shape), a.dtype(), std::move(p), {a});
}
array unflatten(
const array& a,
int axis,
Shape shape,
StreamOrDevice s /* = {} */) {
if (shape.empty()) {
throw std::invalid_argument(
"[unflatten] Shape to unflatten to cannot be empty.");
}
auto ndim = static_cast<int>(a.ndim());
auto ax = axis < 0 ? axis + ndim : axis;
if (ax < 0 || ax >= ndim) {
std::ostringstream msg;
msg << "[unflatten] Invalid axes " << ax << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
size_t size = 1;
int infer_idx = -1;
for (int i = 0; i < shape.size(); ++i) {
if (shape[i] == -1) {
if (infer_idx >= 0) {
throw std::invalid_argument(
"[Unflatten] Can only infer one dimension.");
}
infer_idx = i;
} else {
size *= shape[i];
}
}
if (infer_idx >= 0) {
shape[infer_idx] = a.shape(ax) / size;
size *= shape[infer_idx];
}
if (size != a.shape(ax)) {
std::ostringstream msg;
msg << "[Unflatten] Cannot unflatten axis " << axis << " with size "
<< a.shape(ax) << " into shape " << shape << ".";
throw std::invalid_argument(msg.str());
}
if (shape.size() == 1) {
return a;
}
auto out_shape = Unflatten::output_shape(a, ax, shape);
return array(
std::move(out_shape),
a.dtype(),
std::make_shared<Unflatten>(to_stream(s), ax, std::move(shape)),
{a});
}
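The contract encoded above: at most one `-1` (inferred from the axis size), and the resolved entries must multiply to exactly `a.shape(ax)`. From Python (assuming the usual nanobind mapping of `std::invalid_argument` to `ValueError`):

import mlx.core as mx

x = mx.zeros((4, 8))
print(mx.unflatten(x, 1, (2, 2, 2)).shape)  # (4, 2, 2, 2)
print(mx.unflatten(x, -1, (2, -1)).shape)   # (4, 2, 4): -1 -> 8 // 2
try:
    mx.unflatten(x, 1, (3, 3))              # 3 * 3 != 8
except ValueError as e:
    print(e)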
array flatten(
const array& a,
int start_axis,
@@ -433,11 +485,11 @@ array flatten(
if (start_ax == end_ax) {
return a;
}
Shape new_shape(a.shape().begin(), a.shape().begin() + start_ax);
new_shape.push_back(-1);
new_shape.insert(
new_shape.end(), a.shape().begin() + end_ax + 1, a.shape().end());
return reshape(a, std::move(new_shape), s);
return array(
Flatten::output_shape(a, start_ax, end_ax),
a.dtype(),
std::make_shared<Flatten>(to_stream(s), start_ax, end_ax),
{a});
}
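`flatten` now emits a `Flatten` primitive instead of lowering to `reshape`, so the transform rules below can be phrased in terms of an axis range rather than a concrete shape. The `start_ax == end_ax` early-out makes a single-axis range a no-op:

import mlx.core as mx

x = mx.zeros((2, 3, 4))
print(mx.flatten(x, 1, 1).shape)    # (2, 3, 4): start == end returns a as-is
print(mx.flatten(x, -2, -1).shape)  # (2, 12): negative axes are normalized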
array flatten(const array& a, StreamOrDevice s /* = {} */) {
@@ -901,7 +953,7 @@ array concatenate(
StreamOrDevice s /* = {} */) {
std::vector<array> flat_inputs;
for (auto& a : arrays) {
flat_inputs.push_back(reshape(a, {-1}, s));
flat_inputs.push_back(flatten(a, s));
}
return concatenate(flat_inputs, 0, s);
}
@@ -2568,22 +2620,9 @@ array matmul(
}
// We can batch the multiplication by reshaping a
if (a.ndim() > 2 && b.ndim() == 2) {
std::vector<int> out_shape = a.shape();
a = reshape(a, {-1, out_shape.back()}, s);
out_shape.back() = b.shape(-1);
if (in_b.ndim() == 1) {
out_shape.pop_back();
}
auto out = array(
{a.shape(0), b.shape(1)},
out_type,
std::make_shared<Matmul>(to_stream(s)),
{a, b});
return reshape(out, out_shape, s);
}
if (a.ndim() > 2 || b.ndim() > 2) {
if (in_a.ndim() > 2 && in_b.ndim() <= 2) {
a = flatten(a, 0, -2, s);
} else if (in_b.ndim() > 2) {
Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
@@ -2607,6 +2646,11 @@ array matmul(
out_type,
std::make_shared<Matmul>(to_stream(s)),
{a, b});
if (in_a.ndim() > 2 && in_b.ndim() <= 2) {
auto orig_shape = in_a.shape();
orig_shape.pop_back();
out = unflatten(out, 0, std::move(orig_shape), s);
}
// Remove the possibly inserted singleton dimensions
std::vector<int> axes;
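For `a.ndim() > 2` with an at-most-2-D `b`, matmul now collapses the leading batch axes of `a`, multiplies once, and unflattens the result back to the original batch shape. The same steps by hand:

import mlx.core as mx

a = mx.ones((2, 3, 4))
b = mx.ones((4, 5))
flat = mx.flatten(a, 0, -2)              # (6, 4): fold the leading axes
out = mx.unflatten(flat @ b, 0, (2, 3))  # (2, 3, 5): restore them
print(mx.array_equal(out, a @ b))        # True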
@@ -2753,7 +2797,7 @@ array take(
}
array take(const array& a, const array& indices, StreamOrDevice s /* = {} */) {
return take(reshape(a, {-1}, s), indices, 0, s);
return take(flatten(a, s), indices, 0, s);
}
array take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) {
@@ -2783,7 +2827,7 @@ array take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) {
}
array take(const array& a, int index, StreamOrDevice s /* = {} */) {
return take(reshape(a, {-1}, s), index, 0, s);
return take(flatten(a, s), index, 0, s);
}
array take_along_axis(
@@ -3853,11 +3897,11 @@ array addmm(
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = reshape(a, {1, -1}, s);
a = expand_dims(a, 0, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = reshape(b, {-1, 1}, s);
b = expand_dims(b, 1, s);
}
if (a.shape(-1) != b.shape(-2)) {
@@ -4644,7 +4688,7 @@ array roll(
array roll(const array& a, int shift, StreamOrDevice s /* = {} */) {
auto shape = a.shape();
return reshape(
roll(reshape(a, Shape{-1}, s), Shape{shift}, std::vector<int>{0}, s),
roll(flatten(a, s), Shape{shift}, std::vector<int>{0}, s),
std::move(shape),
s);
}
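Axis-less `roll` still restores the original shape with `reshape`, but the initial flatten no longer needs `{-1}`. Behavior is unchanged:

import mlx.core as mx

x = mx.reshape(mx.arange(6), (2, 3))
print(mx.roll(x, 1))  # [[5, 0, 1], [2, 3, 4]]: rolled as if 1-D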

View File

@@ -117,6 +117,9 @@ array triu(array x, int k = 0, StreamOrDevice s = {});
/** Reshape an array to the given shape. */
array reshape(const array& a, Shape shape, StreamOrDevice s = {});
/** Unflatten the axis to the given shape. */
array unflatten(const array& a, int axis, Shape shape, StreamOrDevice s = {});
/** Flatten the dimensions in the range `[start_axis, end_axis]` . */
array flatten(
const array& a,

View File

@@ -1651,12 +1651,114 @@ std::vector<Shape> ExpandDims::output_shapes(const std::vector<array>& inputs) {
return {ExpandDims::output_shape(inputs[0], axes_)};
}
std::vector<array> Flatten::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>&,
const std::vector<array>&) {
auto& in = primals[0];
Shape unflatten_shape(
in.shape().begin() + start_axis_, in.shape().begin() + end_axis_ + 1);
return {unflatten(
cotangents[0], start_axis_, std::move(unflatten_shape), stream())};
}
std::vector<array> Flatten::jvp(
const std::vector<array>&,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {flatten(tangents[0], start_axis_, end_axis_, stream())};
}
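These rules are the "fix grad" item from the commit message: the vjp of a flatten unflattens the cotangent back to the primal shape, and the jvp flattens the tangent. For instance:

import mlx.core as mx

def f(x):
    return mx.flatten(x, 1, 2)

x = mx.zeros((2, 3, 4))
_, vjps = mx.vjp(f, [x], [mx.ones((2, 12))])
print(vjps[0].shape)  # (2, 3, 4): matches the primal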
std::pair<std::vector<array>, std::vector<int>> Flatten::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0];
auto start_axis = start_axis_;
auto end_axis = end_axis_;
if (ax <= start_axis) {
start_axis++;
end_axis++;
} else {
ax -= (end_axis - start_axis);
}
return {{flatten(inputs[0], start_axis, end_axis, stream())}, {ax}};
}
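The vmap rule shifts the flatten range right when the batch axis precedes it; otherwise the output batch axis moves left by the number of collapsed dimensions. Batching over axis 0:

import mlx.core as mx

x = mx.zeros((5, 2, 3, 4))  # batch of 5
out = mx.vmap(lambda t: mx.flatten(t, 1, 2))(x)
print(out.shape)  # (5, 2, 12): the range [1, 2] was applied as [2, 3]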
bool Flatten::is_equivalent(const Primitive& other) const {
const Flatten& a_other = static_cast<const Flatten&>(other);
return start_axis_ == a_other.start_axis_ && end_axis_ == a_other.end_axis_;
}
Shape Flatten::output_shape(const array& input, int start_axis, int end_axis) {
Shape shape = input.shape();
auto flat_size = input.shape(start_axis);
for (int ax = start_axis + 1; ax <= end_axis; ++ax) {
flat_size *= input.shape(ax);
}
shape.erase(shape.begin() + start_axis + 1, shape.begin() + end_axis + 1);
shape[start_axis] = flat_size;
return shape;
}
std::vector<Shape> Flatten::output_shapes(const std::vector<array>& inputs) {
return {Flatten::output_shape(inputs[0], start_axis_, end_axis_)};
}
bool FFT::is_equivalent(const Primitive& other) const {
const FFT& r_other = static_cast<const FFT&>(other);
return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&
real_ == r_other.real_;
}
std::vector<array> Unflatten::vjp(
const std::vector<array>&,
const std::vector<array>& cotangents,
const std::vector<int>&,
const std::vector<array>&) {
return {flatten(cotangents[0], axis_, axis_ + shape_.size() - 1, stream())};
}
std::vector<array> Unflatten::jvp(
const std::vector<array>&,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {unflatten(tangents[0], axis_, shape_, stream())};
}
std::pair<std::vector<array>, std::vector<int>> Unflatten::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0];
auto axis = axis_;
if (ax <= axis_) {
axis++;
} else {
ax += (shape_.size() - 1);
}
return {{unflatten(inputs[0], axis, shape_, stream())}, {ax}};
}
bool Unflatten::is_equivalent(const Primitive& other) const {
const auto& a_other = static_cast<const Unflatten&>(other);
return axis_ == a_other.axis_ && shape_ == a_other.shape_;
}
Shape Unflatten::output_shape(
const array& input,
int axis,
const Shape& shape) {
Shape out_shape = input.shape();
out_shape[axis] = shape[0];
out_shape.insert(
out_shape.begin() + axis + 1, shape.begin() + 1, shape.end());
return out_shape;
}
std::vector<Shape> Unflatten::output_shapes(const std::vector<array>& inputs) {
return {Unflatten::output_shape(inputs[0], axis_, shape_)};
}
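`Unflatten` mirrors `Flatten`: its vjp flattens the cotangent over the `shape_.size()` expanded axes, and `output_shape` splices the new sizes in place of the original axis. A quick round trip:

import mlx.core as mx

def g(x):
    return mx.unflatten(x, 1, (3, 4))

x = mx.zeros((2, 12))
outs, vjps = mx.vjp(g, [x], [mx.ones((2, 3, 4))])
print(outs[0].shape, vjps[0].shape)  # (2, 3, 4) (2, 12)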
std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {

View File

@@ -1031,6 +1031,28 @@ class FFT : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Flatten : public UnaryPrimitive {
public:
explicit Flatten(Stream stream, int start_axis, int end_axis)
: UnaryPrimitive(stream), start_axis_(start_axis), end_axis_(end_axis) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Flatten)
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, int start_axis, int end_axis);
private:
int start_axis_;
int end_axis_;
void eval(const std::vector<array>& inputs, array& out);
};
class Floor : public UnaryPrimitive {
public:
explicit Floor(Stream stream) : UnaryPrimitive(stream) {}
@@ -1643,16 +1665,6 @@ class Reshape : public UnaryPrimitive {
private:
Shape shape_;
void eval(const std::vector<array>& inputs, array& out);
static std::pair<bool, Strides> prepare_reshape(
const array& in,
const array& out);
static void shared_buffer_reshape(
const array& in,
const Strides& out_strides,
array& out);
};
class Reduce : public UnaryPrimitive {
@@ -2137,6 +2149,28 @@ class Tanh : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Unflatten : public UnaryPrimitive {
public:
explicit Unflatten(Stream stream, int axis, Shape shape)
: UnaryPrimitive(stream), axis_(axis), shape_(std::move(shape)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Unflatten)
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, int axis, const Shape& shape);
private:
int axis_;
Shape shape_;
void eval(const std::vector<array>& inputs, array& out);
};
class Uniform : public UnaryPrimitive {
public:
explicit Uniform(Stream stream) : UnaryPrimitive(stream) {}

View File

@@ -405,22 +405,22 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
// Unsqueeze handling
if (unsqueeze_needed || squeeze_needed) {
std::vector<int> out_shape;
int axis = 0;
for (auto& idx : remaining_indices) {
std::vector<int> squeeze_axes;
std::vector<int> unsqueeze_axes;
for (int axis = 0; axis < remaining_indices.size(); ++axis) {
auto& idx = remaining_indices[axis];
if (unsqueeze_needed && idx.is_none()) {
out_shape.push_back(1);
unsqueeze_axes.push_back(axis - squeeze_axes.size());
} else if (squeeze_needed && nb::isinstance<nb::int_>(idx)) {
axis++;
} else {
out_shape.push_back(src.shape(axis++));
squeeze_axes.push_back(axis - unsqueeze_axes.size());
}
}
out_shape.insert(
out_shape.end(), src.shape().begin() + axis, src.shape().end());
src = reshape(src, out_shape);
if (!squeeze_axes.empty()) {
src = squeeze(src, std::move(squeeze_axes));
}
if (!unsqueeze_axes.empty()) {
src = expand_dims(src, std::move(unsqueeze_axes));
}
}
return src;
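With `squeeze`/`expand_dims` in place of `reshape` (the last commit bullet), basic `None`/integer indexing no longer writes concrete sizes into the graph, which is what lets the new shapeless-compile tests below pass. The user-visible result is the same:

import mlx.core as mx

x = mx.zeros((2, 3, 4))
print(x[:, None, 0].shape)  # (2, 1, 4): one axis inserted, one indexed away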

View File

@@ -103,6 +103,36 @@ void init_ops(nb::module_& m) {
>>> mx.flatten(a, start_axis=0, end_axis=-1)
array([1, 2, 3, 4], dtype=int32)
)pbdoc");
m.def(
"unflatten",
&unflatten,
nb::arg(),
"axis"_a,
"shape"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def unflatten(a: array, /, axis: int, shape: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Unflatten an axis of an array to a shape.
Args:
a (array): Input array.
axis (int): The axis to unflatten.
shape (tuple(int)): The shape to unflatten to. At most one
entry can be ``-1`` in which case the corresponding size will be
inferred.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The unflattened array.
Example:
>>> a = mx.array([1, 2, 3, 4])
>>> mx.unflatten(a, 0, (2, -1))
array([[1, 2], [3, 4]], dtype=int32)
)pbdoc");
m.def(
"squeeze",
[](const mx::array& a, const IntOrVec& v, const mx::StreamOrDevice& s) {

View File

@@ -462,6 +462,22 @@ class TestCompile(mlx_tests.MLXTestCase):
cfun = mx.compile(fun, shapeless=True)
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
def test_shapeless_compile_unflatten(self):
x = mx.zeros((1, 1, 4 * 32))
def fun(x):
return mx.unflatten(x, -1, (4, -1))
self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 1, 4, 32))
def test_shapeless_compile_gather(self):
x = mx.zeros((1, 1, 32))
def fun(x):
return x[:, -1, :]
self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 32))
def test_compile_with_constant(self):
# Test float
@partial(mx.compile)

View File

@@ -163,6 +163,23 @@ TEST_CASE("test flatten") {
CHECK_EQ(flatten(x, 0, 0).shape(), Shape({1}));
}
TEST_CASE("test unflatten") {
array x = array(1);
CHECK_THROWS(unflatten(x, 0, {1, 1}));
x = array({1});
auto out = unflatten(x, 0, {1, 1});
CHECK_EQ(out.shape(), Shape({1, 1}));
CHECK_THROWS(unflatten(x, 1, {1, 1}));
CHECK_THROWS(unflatten(x, 0, {-1, -1}));
CHECK_THROWS(unflatten(x, 0, {-1, 2}));
CHECK_THROWS(unflatten(x, 0, {}));
x = zeros({4, 8});
out = unflatten(x, 1, {2, 2, 2});
CHECK_EQ(out.shape(), Shape({4, 2, 2, 2}));
}
TEST_CASE("test squeeze and expand") {
array x = zeros({2, 1, 2, 1, 2, 1});
CHECK_EQ(squeeze(x).shape(), Shape{2, 2, 2});