Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-17 14:31:14 +08:00)
Flatten and unflatten (#1692)

* flatten and unflatten
* fix grad
* fix shape infer
* use squeeze + unsqueeze in get_item

parent 0bf19037ca
commit 4e1e9520e1
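The PR replaces ad-hoc `reshape` calls with two dedicated shape primitives. A minimal usage sketch of the new Python-facing ops (assumes `mlx.core` imported as `mx`; shapes are illustrative):

>>> import mlx.core as mx
>>> x = mx.zeros((2, 3, 4))
>>> mx.flatten(x, 1, 2).shape            # merge axes 1..2 (inclusive)
(2, 12)
>>> mx.unflatten(mx.zeros((2, 12)), 1, (3, -1)).shape  # split axis 1; -1 inferred
(2, 3, 4)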
@@ -168,6 +168,7 @@ Operations
    tri
    tril
    triu
+   unflatten
    var
    view
    where
@@ -66,7 +66,6 @@ DEFAULT(Pad)
 DEFAULT(Partition)
 DEFAULT_MULTI(QRF)
 DEFAULT(RandomBits)
-DEFAULT(Reshape)
 DEFAULT(Remainder)
 DEFAULT(Round)
 DEFAULT(Scatter)
@@ -151,9 +151,7 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
-std::pair<bool, Strides> Reshape::prepare_reshape(
-    const array& in,
-    const array& out) {
+std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
   // Special case for empty arrays or row contiguous arrays
   if (in.size() == 0 || in.flags().row_contiguous) {
     return {false, out.strides()};
@@ -190,7 +188,7 @@ std::pair<bool, Strides> Reshape::prepare_reshape(
   return {copy_necessary, out_strides};
 }
 
-void Reshape::shared_buffer_reshape(
+void shared_buffer_reshape(
     const array& in,
     const Strides& out_strides,
     array& out) {
@@ -87,7 +87,6 @@ DEFAULT_MULTI(QRF)
 DEFAULT(QuantizedMatmul)
 DEFAULT(RandomBits)
 DEFAULT(Reduce)
-DEFAULT(Reshape)
 DEFAULT(Round)
 DEFAULT(Scan)
 DEFAULT(Scatter)
@@ -19,6 +19,16 @@
 
 namespace mlx::core {
 
+void reshape(const array& in, array& out) {
+  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
+  if (copy_necessary) {
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    copy_inplace(in, out, CopyType::General);
+  } else {
+    shared_buffer_reshape(in, out_strides, out);
+  }
+}
+
 void Abs::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
@@ -258,6 +268,14 @@ void Expm1::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
+void Flatten::eval_cpu(const std::vector<array>& inputs, array& out) {
+  reshape(inputs[0], out);
+}
+
+void Unflatten::eval_cpu(const std::vector<array>& inputs, array& out) {
+  reshape(inputs[0], out);
+}
+
 void Floor::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
@@ -417,18 +435,8 @@ void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
   unary_op<complex64_t, float>(inputs[0], out, detail::Real());
 }
 
-void Reshape::eval(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-
-  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
-
-  if (copy_necessary) {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    copy_inplace(in, out, CopyType::General);
-  } else {
-    shared_buffer_reshape(in, out_strides, out);
-  }
+void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
+  reshape(inputs[0], out);
 }
 
 void Round::eval(const std::vector<array>& inputs, array& out) {
@@ -168,4 +168,10 @@ void move_or_copy(
     size_t data_size,
     size_t offset = 0);
 
+std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
+
+void shared_buffer_reshape(
+    const array& in,
+    const Strides& out_strides,
+    array& out);
 } // namespace mlx::core
@@ -25,6 +25,25 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
   enc.set_bytes(step, 1);
 }
 
+void reshape(const array& in, array& out, Stream s) {
+  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
+  if (copy_necessary) {
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    copy_gpu_inplace(
+        in,
+        out,
+        in.shape(),
+        in.strides(),
+        make_contiguous_strides(in.shape()),
+        0,
+        0,
+        CopyType::General,
+        s);
+  } else {
+    shared_buffer_reshape(in, out_strides, out);
+  }
+}
+
 void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 0);
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -215,6 +234,14 @@ void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
   eval(inputs, out);
 }
 
+void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
+  reshape(inputs[0], out, stream());
+}
+
+void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
+  reshape(inputs[0], out, stream());
+}
+
 void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
   auto read_task = [out = out,
@@ -309,26 +336,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
 }
 
 void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-
-  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
-
-  if (copy_necessary) {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    copy_gpu_inplace(
-        in,
-        out,
-        in.shape(),
-        in.strides(),
-        make_contiguous_strides(in.shape()),
-        0,
-        0,
-        CopyType::General,
-        stream());
-  } else {
-    shared_buffer_reshape(in, out_strides, out);
-  }
+  reshape(inputs[0], out, stream());
 }
 
 void Split::eval_gpu(
@@ -58,6 +58,7 @@ NO_CPU(Exp)
 NO_CPU(ExpandDims)
 NO_CPU(Expm1)
 NO_CPU(FFT)
+NO_CPU(Flatten)
 NO_CPU(Floor)
 NO_CPU(Full)
 NO_CPU(Gather)
@@ -113,6 +114,7 @@ NO_CPU_MULTI(SVD)
 NO_CPU(Tan)
 NO_CPU(Tanh)
 NO_CPU(Transpose)
+NO_CPU(Unflatten)
 NO_CPU(Inverse)
 NO_CPU(View)
@@ -58,6 +58,7 @@ NO_GPU(Exp)
 NO_GPU(ExpandDims)
 NO_GPU(Expm1)
 NO_GPU(FFT)
+NO_GPU(Flatten)
 NO_GPU(Floor)
 NO_GPU(Full)
 NO_GPU(Gather)
@@ -113,6 +114,7 @@ NO_GPU_MULTI(SVD)
 NO_GPU(Tan)
 NO_GPU(Tanh)
 NO_GPU(Transpose)
+NO_GPU(Unflatten)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)
 NO_GPU_MULTI(Eigh)
@@ -82,6 +82,7 @@ bool allows_shapeless(const Primitive& p) {
       typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
       typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) ||
       typeid(p) == typeid(Squeeze) || typeid(p) == typeid(ExpandDims) ||
+      typeid(p) == typeid(Flatten) || typeid(p) == typeid(Unflatten) ||
       typeid(p) == typeid(fast::AffineQuantize) ||
       typeid(p) == typeid(fast::LayerNorm) ||
       typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||
@@ -614,7 +614,7 @@ array scaled_dot_product_attention(
     auto k = inputs[1];
     auto v = inputs[2];
     if (n_repeats > 1) {
-      q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s);
+      q = unflatten(q, 1, {n_kv_heads, n_repeats}, s);
       k = expand_dims(k, 2, s);
       v = expand_dims(v, 2, s);
     }
@@ -629,7 +629,7 @@ array scaled_dot_product_attention(
     scores = softmax(scores, std::vector<int>{-1}, true, s);
     auto out = matmul(scores, v, s);
     if (n_repeats > 1) {
-      out = reshape(out, {B, n_q_heads, L, -1}, s);
+      out = flatten(out, 1, 2, s);
     }
     return std::vector<array>{out};
   };
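The attention fallback previously respelled the full query shape with `reshape`; `unflatten` targets only the head axis, which is what lets the primitive pass the shapeless-compile check above. A rough Python analogue (hypothetical sizes; a sketch, not the library's internal code):

>>> import mlx.core as mx
>>> B, n_kv_heads, n_repeats, L, D = 1, 4, 2, 10, 64
>>> q = mx.zeros((B, n_kv_heads * n_repeats, L, D))
>>> mx.unflatten(q, 1, (n_kv_heads, n_repeats)).shape
(1, 4, 2, 10, 64)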
mlx/ops.cpp
@@ -267,7 +267,7 @@ array as_strided(
       std::make_shared<AsStrided>(
           to_stream(s), std::move(shape), std::move(strides), offset),
       // Force the input array to be contiguous.
-      {reshape(std::move(a), {-1}, s)});
+      {flatten(std::move(a), s)});
 }
 
 array copy(array a, StreamOrDevice s /* = {} */) {
@@ -380,10 +380,9 @@ array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
 
   // Infer the shape
   if (size > 0) {
-    auto q_and_r = std::ldiv(a.size(), size);
     if (infer_idx >= 0) {
-      shape[infer_idx] = q_and_r.quot;
-      size *= q_and_r.quot;
+      shape[infer_idx] = a.size() / size;
+      size *= shape[infer_idx];
     }
   } else if (infer_idx >= 0) {
     throw std::invalid_argument(
@@ -401,6 +400,59 @@ array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
   return array(std::move(shape), a.dtype(), std::move(p), {a});
 }
 
+array unflatten(
+    const array& a,
+    int axis,
+    Shape shape,
+    StreamOrDevice s /* = {} */) {
+  if (shape.empty()) {
+    throw std::invalid_argument(
+        "[unflatten] Shape to unflatten to cannot be empty.");
+  }
+  auto ndim = static_cast<int>(a.ndim());
+  auto ax = axis < 0 ? axis + ndim : axis;
+  if (ax < 0 || ax >= ndim) {
+    std::ostringstream msg;
+    msg << "[unflatten] Invalid axes " << ax << " for array with " << a.ndim()
+        << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  size_t size = 1;
+  int infer_idx = -1;
+  for (int i = 0; i < shape.size(); ++i) {
+    if (shape[i] == -1) {
+      if (infer_idx >= 0) {
+        throw std::invalid_argument(
+            "[Unflatten] Can only infer one dimension.");
+      }
+      infer_idx = i;
+    } else {
+      size *= shape[i];
+    }
+  }
+  if (infer_idx >= 0) {
+    shape[infer_idx] = a.shape(ax) / size;
+    size *= shape[infer_idx];
+  }
+  if (size != a.shape(ax)) {
+    std::ostringstream msg;
+    msg << "[Unflatten] Cannot unflatten axis " << axis << " with size "
+        << a.shape(ax) << " into shape " << shape << ".";
+    throw std::invalid_argument(msg.str());
+  }
+  if (shape.size() == 1) {
+    return a;
+  }
+
+  auto out_shape = Unflatten::output_shape(a, ax, shape);
+  return array(
+      std::move(out_shape),
+      a.dtype(),
+      std::make_shared<Unflatten>(to_stream(s), ax, std::move(shape)),
+      {a});
+}
+
 array flatten(
     const array& a,
     int start_axis,
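The inference branch above mirrors `reshape`: at most one `-1`, replaced by `a.shape(ax) / size`, and the resulting product is validated against the axis size. A sketch via the Python binding (assuming, as is usual for nanobind, that `std::invalid_argument` surfaces as a Python `ValueError`):

>>> import mlx.core as mx
>>> x = mx.arange(12)
>>> mx.unflatten(x, 0, (3, -1)).shape   # -1 inferred as 12 // 3 == 4
(3, 4)
>>> mx.unflatten(x, 0, (5, -1))         # 5 * (12 // 5) == 10 != 12: raises ValueError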
@@ -433,11 +485,11 @@ array flatten(
   if (start_ax == end_ax) {
     return a;
   }
-  Shape new_shape(a.shape().begin(), a.shape().begin() + start_ax);
-  new_shape.push_back(-1);
-  new_shape.insert(
-      new_shape.end(), a.shape().begin() + end_ax + 1, a.shape().end());
-  return reshape(a, std::move(new_shape), s);
+  return array(
+      Flatten::output_shape(a, start_ax, end_ax),
+      a.dtype(),
+      std::make_shared<Flatten>(to_stream(s), start_ax, end_ax),
+      {a});
 }
 
 array flatten(const array& a, StreamOrDevice s /* = {} */) {
@@ -901,7 +953,7 @@ array concatenate(
     StreamOrDevice s /* = {} */) {
   std::vector<array> flat_inputs;
   for (auto& a : arrays) {
-    flat_inputs.push_back(reshape(a, {-1}, s));
+    flat_inputs.push_back(flatten(a, s));
   }
   return concatenate(flat_inputs, 0, s);
 }
@@ -2568,22 +2620,9 @@ array matmul(
   }
 
-  // We can batch the multiplication by reshaping a
-  if (a.ndim() > 2 && b.ndim() == 2) {
-    std::vector<int> out_shape = a.shape();
-    a = reshape(a, {-1, out_shape.back()}, s);
-    out_shape.back() = b.shape(-1);
-    if (in_b.ndim() == 1) {
-      out_shape.pop_back();
-    }
-    auto out = array(
-        {a.shape(0), b.shape(1)},
-        out_type,
-        std::make_shared<Matmul>(to_stream(s)),
-        {a, b});
-    return reshape(out, out_shape, s);
-  }
-
   if (a.ndim() > 2 || b.ndim() > 2) {
+    if (in_a.ndim() > 2 && in_b.ndim() <= 2) {
+      a = flatten(a, 0, -2, s);
+    } else if (in_b.ndim() > 2) {
     Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
     Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
     auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
@@ -2607,6 +2646,11 @@ array matmul(
       out_type,
       std::make_shared<Matmul>(to_stream(s)),
       {a, b});
+  if (in_a.ndim() > 2 && in_b.ndim() <= 2) {
+    auto orig_shape = in_a.shape();
+    orig_shape.pop_back();
+    out = unflatten(out, 0, std::move(orig_shape), s);
+  }
 
   // Remove the possibly inserted singleton dimensions
   std::vector<int> axes;
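`matmul` with a batched `a` and at most 2-D `b` now flattens the leading batch axes into one, runs a single 2-D matmul, and unflattens the result; since `Flatten`/`Unflatten` are shape-only primitives, this adds no copy for contiguous inputs. The round trip, sketched in Python with made-up shapes:

>>> import mlx.core as mx
>>> a, b = mx.zeros((5, 7, 3, 4)), mx.zeros((4, 2))
>>> flat = mx.flatten(a, 0, -2)                   # (105, 4): batch axes merged
>>> out = mx.unflatten(flat @ b, 0, (5, 7, 3))    # back to (5, 7, 3, 2)
>>> out.shape == (a @ b).shape
True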
@@ -2753,7 +2797,7 @@ array take(
 }
 
 array take(const array& a, const array& indices, StreamOrDevice s /* = {} */) {
-  return take(reshape(a, {-1}, s), indices, 0, s);
+  return take(flatten(a, s), indices, 0, s);
 }
 
 array take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) {
@@ -2783,7 +2827,7 @@ array take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) {
 }
 
 array take(const array& a, int index, StreamOrDevice s /* = {} */) {
-  return take(reshape(a, {-1}, s), index, 0, s);
+  return take(flatten(a, s), index, 0, s);
 }
 
 array take_along_axis(
@@ -3853,11 +3897,11 @@ array addmm(
 
   if (a.ndim() == 1) {
     // Insert a singleton dim in the beginning
-    a = reshape(a, {1, -1}, s);
+    a = expand_dims(a, 0, s);
   }
   if (b.ndim() == 1) {
     // Insert a singleton dim at the end
-    b = reshape(b, {-1, 1}, s);
+    b = expand_dims(b, 1, s);
   }
 
   if (a.shape(-1) != b.shape(-2)) {
@@ -4644,7 +4688,7 @@ array roll(
 array roll(const array& a, int shift, StreamOrDevice s /* = {} */) {
   auto shape = a.shape();
   return reshape(
-      roll(reshape(a, Shape{-1}, s), Shape{shift}, std::vector<int>{0}, s),
+      roll(flatten(a, s), Shape{shift}, std::vector<int>{0}, s),
       std::move(shape),
       s);
 }
@@ -117,6 +117,9 @@ array triu(array x, int k = 0, StreamOrDevice s = {});
 /** Reshape an array to the given shape. */
 array reshape(const array& a, Shape shape, StreamOrDevice s = {});
 
+/** Unflatten the axis to the given shape. */
+array unflatten(const array& a, int axis, Shape shape, StreamOrDevice s = {});
+
 /** Flatten the dimensions in the range `[start_axis, end_axis]` . */
 array flatten(
     const array& a,
@@ -1651,12 +1651,114 @@ std::vector<Shape> ExpandDims::output_shapes(const std::vector<array>& inputs) {
   return {ExpandDims::output_shape(inputs[0], axes_)};
 }
 
+std::vector<array> Flatten::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>&,
+    const std::vector<array>&) {
+  auto& in = primals[0];
+  Shape unflatten_shape(
+      in.shape().begin() + start_axis_, in.shape().begin() + end_axis_ + 1);
+  return {unflatten(
+      cotangents[0], start_axis_, std::move(unflatten_shape), stream())};
+}
+
+std::vector<array> Flatten::jvp(
+    const std::vector<array>&,
+    const std::vector<array>& tangents,
+    const std::vector<int>&) {
+  return {flatten(tangents[0], start_axis_, end_axis_, stream())};
+}
+
+std::pair<std::vector<array>, std::vector<int>> Flatten::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto ax = axes[0];
+  auto start_axis = start_axis_;
+  auto end_axis = end_axis_;
+  if (ax < start_axis) {
+    start_axis++;
+    end_axis++;
+  } else {
+    ax -= (end_axis - start_axis);
+  }
+  return {{flatten(inputs[0], start_axis, end_axis, stream())}, {ax}};
+}
+
+bool Flatten::is_equivalent(const Primitive& other) const {
+  const Flatten& a_other = static_cast<const Flatten&>(other);
+  return start_axis_ == a_other.start_axis_ && end_axis_ == a_other.end_axis_;
+}
+
+Shape Flatten::output_shape(const array& input, int start_axis, int end_axis) {
+  Shape shape = input.shape();
+  auto flat_size = input.shape(start_axis);
+  for (int ax = start_axis + 1; ax <= end_axis; ++ax) {
+    flat_size *= input.shape(ax);
+  }
+  shape.erase(shape.begin() + start_axis + 1, shape.begin() + end_axis + 1);
+  shape[start_axis] = flat_size;
+  return shape;
+}
+
+std::vector<Shape> Flatten::output_shapes(const std::vector<array>& inputs) {
+  return {Flatten::output_shape(inputs[0], start_axis_, end_axis_)};
+}
+
 bool FFT::is_equivalent(const Primitive& other) const {
   const FFT& r_other = static_cast<const FFT&>(other);
   return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&
       real_ == r_other.real_;
 }
 
+std::vector<array> Unflatten::vjp(
+    const std::vector<array>&,
+    const std::vector<array>& cotangents,
+    const std::vector<int>&,
+    const std::vector<array>&) {
+  return {flatten(cotangents[0], axis_, axis_ + shape_.size() - 1, stream())};
+}
+
+std::vector<array> Unflatten::jvp(
+    const std::vector<array>&,
+    const std::vector<array>& tangents,
+    const std::vector<int>&) {
+  return {unflatten(tangents[0], axis_, shape_, stream())};
+}
+
+std::pair<std::vector<array>, std::vector<int>> Unflatten::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto ax = axes[0];
+  auto axis = axis_;
+  if (ax <= axis_) {
+    axis++;
+  } else {
+    ax += (shape_.size() - 1);
+  }
+  return {{unflatten(inputs[0], axis, shape_, stream())}, {ax}};
+}
+
+bool Unflatten::is_equivalent(const Primitive& other) const {
+  const auto& a_other = static_cast<const Unflatten&>(other);
+  return axis_ == a_other.axis_ && shape_ == a_other.shape_;
+}
+
+Shape Unflatten::output_shape(
+    const array& input,
+    int axis,
+    const Shape& shape) {
+  Shape out_shape = input.shape();
+  out_shape[axis] = shape[0];
+  out_shape.insert(
+      out_shape.begin() + axis + 1, shape.begin() + 1, shape.end());
+  return out_shape;
+}
+
+std::vector<Shape> Unflatten::output_shapes(const std::vector<array>& inputs) {
+  return {Unflatten::output_shape(inputs[0], axis_, shape_)};
+}
+
 std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
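The grads pair up: `Flatten::vjp` unflattens the cotangent back to the original sub-shape, and `Unflatten::vjp` flattens the unflattened axes (the `- 1` accounts for `flatten`'s inclusive `end_axis`). A quick check through the Python autograd (a sketch):

>>> import mlx.core as mx
>>> f = lambda x: mx.flatten(x, 0, 1).sum()
>>> mx.grad(f)(mx.ones((2, 3, 4))).shape   # gradient restored to input shape
(2, 3, 4)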
@@ -1031,6 +1031,28 @@ class FFT : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class Flatten : public UnaryPrimitive {
+ public:
+  explicit Flatten(Stream stream, int start_axis, int end_axis)
+      : UnaryPrimitive(stream), start_axis_(start_axis), end_axis_(end_axis) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(Flatten)
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
+  bool is_equivalent(const Primitive& other) const override;
+
+  static Shape output_shape(const array& input, int start_axis, int end_axis);
+
+ private:
+  int start_axis_;
+  int end_axis_;
+  void eval(const std::vector<array>& inputs, array& out);
+};
+
 class Floor : public UnaryPrimitive {
  public:
   explicit Floor(Stream stream) : UnaryPrimitive(stream) {}
@@ -1643,16 +1665,6 @@ class Reshape : public UnaryPrimitive {
 
  private:
   Shape shape_;
-
-  void eval(const std::vector<array>& inputs, array& out);
-
-  static std::pair<bool, Strides> prepare_reshape(
-      const array& in,
-      const array& out);
-  static void shared_buffer_reshape(
-      const array& in,
-      const Strides& out_strides,
-      array& out);
 };
 
 class Reduce : public UnaryPrimitive {
@@ -2137,6 +2149,28 @@ class Tanh : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class Unflatten : public UnaryPrimitive {
+ public:
+  explicit Unflatten(Stream stream, int axis, Shape shape)
+      : UnaryPrimitive(stream), axis_(axis), shape_(std::move(shape)) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(Unflatten)
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
+  bool is_equivalent(const Primitive& other) const override;
+
+  static Shape output_shape(const array& input, int axis, const Shape& shape);
+
+ private:
+  int axis_;
+  Shape shape_;
+  void eval(const std::vector<array>& inputs, array& out);
+};
+
 class Uniform : public UnaryPrimitive {
  public:
   explicit Uniform(Stream stream) : UnaryPrimitive(stream) {}
@@ -405,22 +405,22 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
 
   // Unsqueeze handling
   if (unsqueeze_needed || squeeze_needed) {
-    std::vector<int> out_shape;
-    int axis = 0;
-    for (auto& idx : remaining_indices) {
+    std::vector<int> squeeze_axes;
+    std::vector<int> unsqueeze_axes;
+    for (int axis = 0; axis < remaining_indices.size(); ++axis) {
+      auto& idx = remaining_indices[axis];
       if (unsqueeze_needed && idx.is_none()) {
-        out_shape.push_back(1);
+        unsqueeze_axes.push_back(axis - squeeze_axes.size());
       } else if (squeeze_needed && nb::isinstance<nb::int_>(idx)) {
-        axis++;
-      } else {
-        out_shape.push_back(src.shape(axis++));
+        squeeze_axes.push_back(axis - unsqueeze_axes.size());
       }
     }
-
-    out_shape.insert(
-        out_shape.end(), src.shape().begin() + axis, src.shape().end());
-
-    src = reshape(src, out_shape);
+    if (!squeeze_axes.empty()) {
+      src = squeeze(src, std::move(squeeze_axes));
+    }
+    if (!unsqueeze_axes.empty()) {
+      src = expand_dims(src, std::move(unsqueeze_axes));
+    }
   }
 
   return src;
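With this change, integer indices become `squeeze` axes and `None` entries become `expand_dims` axes instead of one `reshape` to a computed shape, so the result no longer depends on concrete sizes. Illustratively (a sketch, following NumPy-style indexing semantics):

>>> import mlx.core as mx
>>> x = mx.zeros((2, 3, 4))
>>> x[0, None, :, 2].shape   # int indices squeeze; None unsqueezes
(1, 3)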
@@ -103,6 +103,36 @@ void init_ops(nb::module_& m) {
           >>> mx.flatten(a, start_axis=0, end_axis=-1)
           array([1, 2, 3, 4], dtype=int32)
       )pbdoc");
+  m.def(
+      "unflatten",
+      &unflatten,
+      nb::arg(),
+      "axis"_a,
+      "shape"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def unflatten(a: array, /, axis: int, shape: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Unflatten an axis of an array to a shape.
+
+        Args:
+            a (array): Input array.
+            axis (int): The axis to unflatten.
+            shape (tuple(int)): The shape to unflatten to. At most one
+              entry can be ``-1`` in which case the corresponding size will be
+              inferred.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            array: The unflattened array.
+
+        Example:
+            >>> a = mx.array([1, 2, 3, 4])
+            >>> mx.unflatten(a, 0, (2, -1))
+            array([[1, 2], [3, 4]], dtype=int32)
+      )pbdoc");
   m.def(
       "squeeze",
       [](const mx::array& a, const IntOrVec& v, const mx::StreamOrDevice& s) {
@@ -462,6 +462,22 @@ class TestCompile(mlx_tests.MLXTestCase):
         cfun = mx.compile(fun, shapeless=True)
         self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
 
+    def test_shapeless_compile_unflatten(self):
+        x = mx.zeros((1, 1, 4 * 32))
+
+        def fun(x):
+            return mx.unflatten(x, -1, (4, -1))
+
+        self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 1, 4, 32))
+
+    def test_shapeless_compile_gather(self):
+        x = mx.zeros((1, 1, 32))
+
+        def fun(x):
+            return x[:, -1, :]
+
+        self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 32))
+
     def test_compile_with_constant(self):
         # Test float
         @partial(mx.compile)
@@ -163,6 +163,23 @@ TEST_CASE("test flatten") {
   CHECK_EQ(flatten(x, 0, 0).shape(), Shape({1}));
 }
 
+TEST_CASE("test unflatten") {
+  array x = array(1);
+  CHECK_THROWS(unflatten(x, 0, {1, 1}));
+
+  x = array({1});
+  auto out = unflatten(x, 0, {1, 1});
+  CHECK_EQ(out.shape(), Shape({1, 1}));
+  CHECK_THROWS(unflatten(x, 1, {1, 1}));
+  CHECK_THROWS(unflatten(x, 0, {-1, -1}));
+  CHECK_THROWS(unflatten(x, 0, {-1, 2}));
+  CHECK_THROWS(unflatten(x, 0, {}));
+
+  x = zeros({4, 8});
+  out = unflatten(x, 1, {2, 2, 2});
+  CHECK_EQ(out.shape(), Shape({4, 2, 2, 2}));
+}
+
 TEST_CASE("test squeeze and expand") {
   array x = zeros({2, 1, 2, 1, 2, 1});
   CHECK_EQ(squeeze(x).shape(), Shape{2, 2, 2});