ExpandDims primitive (#1687)

* add squeeze primitive

* simplify squeeze, use in gather

* fix

* fix

* fix

* fix

* fix no cpu

* use squeeze in matmul and friends

* expand dims primitive

* comment
This commit is contained in:
Awni Hannun 2024-12-10 16:39:07 -08:00 committed by GitHub
parent 310ad8d9db
commit f76a49e555
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 373 additions and 184 deletions

View File

@ -43,6 +43,7 @@ DEFAULT(NumberOfElements)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(ExpandDims)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Gather)
@ -76,6 +77,7 @@ DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(Squeeze)
DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)

View File

@ -85,6 +85,16 @@ void Depends::eval(
}
}
void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto strides = in.strides();
for (auto ax : axes_) {
strides.insert(strides.begin() + ax, 1);
}
move_or_copy(in, out, strides, in.flags(), in.data_size());
}
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@ -248,6 +258,20 @@ void Split::eval(
}
}
void Squeeze::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
Strides strides;
for (int i = 0, j = 0; i < in.ndim(); ++i) {
if (j < axes_.size() && i == axes_[j]) {
j++;
} else {
strides.push_back(in.strides(i));
}
}
move_or_copy(in, out, strides, in.flags(), in.data_size());
}
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
move_or_copy(inputs[0], out);

View File

@ -57,6 +57,7 @@ DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(Exp)
DEFAULT(ExpandDims)
DEFAULT(Expm1)
DEFAULT(FFT)
DEFAULT(Floor)
@ -101,6 +102,7 @@ DEFAULT(Softmax)
DEFAULT(Sort)
DEFAULT_MULTI(Split)
DEFAULT(Square)
DEFAULT(Squeeze)
DEFAULT(Sqrt)
DEFAULT(StopGradient)
DEFAULT(Subtract)

View File

@ -211,6 +211,10 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
copy_gpu(in, out, ctype);
}
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto read_task = [out = out,
@ -381,6 +385,10 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const Stream& s = */ stream());
}
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}

View File

@ -55,6 +55,7 @@ NO_CPU(Equal)
NO_CPU(Erf)
NO_CPU(ErfInv)
NO_CPU(Exp)
NO_CPU(ExpandDims)
NO_CPU(Expm1)
NO_CPU(FFT)
NO_CPU(Floor)
@ -104,6 +105,7 @@ NO_CPU(Softmax)
NO_CPU(Sort)
NO_CPU_MULTI(Split)
NO_CPU(Square)
NO_CPU(Squeeze)
NO_CPU(Sqrt)
NO_CPU(StopGradient)
NO_CPU(Subtract)

View File

@ -55,6 +55,7 @@ NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)
NO_GPU(Exp)
NO_GPU(ExpandDims)
NO_GPU(Expm1)
NO_GPU(FFT)
NO_GPU(Floor)
@ -104,6 +105,7 @@ NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU_MULTI(Split)
NO_GPU(Square)
NO_GPU(Squeeze)
NO_GPU(Sqrt)
NO_GPU(StopGradient)
NO_GPU(Subtract)

View File

@ -81,6 +81,7 @@ bool allows_shapeless(const Primitive& p) {
typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) ||
typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) ||
typeid(p) == typeid(Squeeze) || typeid(p) == typeid(ExpandDims) ||
typeid(p) == typeid(fast::AffineQuantize) ||
typeid(p) == typeid(fast::LayerNorm) ||
typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||

View File

@ -20,7 +20,7 @@ namespace mlx::core {
namespace {
std::tuple<Shape, std::vector<int>, Shape, bool> compute_reduce_shape(
std::tuple<Shape, std::vector<int>, bool> compute_reduce_shape(
const std::vector<int>& axes,
const Shape& shape) {
bool is_noop = true;
@ -40,18 +40,16 @@ std::tuple<Shape, std::vector<int>, Shape, bool> compute_reduce_shape(
throw std::invalid_argument("Duplicate axes detected in reduction.");
}
Shape out_shape;
Shape squeezed_shape;
for (int i = 0; i < ndim; ++i) {
if (axes_set.count(i) == 0) {
out_shape.push_back(shape[i]);
squeezed_shape.push_back(shape[i]);
} else {
out_shape.push_back(1);
}
is_noop &= (out_shape.back() == shape[i]);
}
std::vector<int> sorted_axes(axes_set.begin(), axes_set.end());
return {out_shape, sorted_axes, squeezed_shape, is_noop};
return {out_shape, sorted_axes, is_noop};
}
Dtype at_least_float(const Dtype& d) {
@ -460,54 +458,51 @@ array hadamard_transform(
{astype(a, dtype, s)});
}
array squeeze_impl(
const array& a,
std::vector<int> axes,
StreamOrDevice s /* = {} */) {
for (auto& ax : axes) {
auto new_ax = ax < 0 ? ax + a.ndim() : ax;
if (new_ax < 0 || new_ax >= a.ndim()) {
std::ostringstream msg;
msg << "[squeeze] Invalid axes " << ax << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
if (a.shape(new_ax) != 1) {
std::ostringstream msg;
msg << "[squeeze] Cannot squeeze axis " << ax << " with size "
<< a.shape(ax) << " which is not equal to 1.";
throw std::invalid_argument(msg.str());
}
ax = new_ax;
}
auto shape = Squeeze::output_shape(a, axes);
return array(
std::move(shape),
a.dtype(),
std::make_shared<Squeeze>(to_stream(s), std::move(axes)),
{a});
}
array squeeze(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
std::set<int> unique_axes;
for (auto ax : axes) {
ax = ax < 0 ? ax + a.ndim() : ax;
if (ax < 0 || ax >= a.ndim()) {
std::ostringstream msg;
msg << "[squeeze] Invalid axes " << ax << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
unique_axes.insert(ax < 0 ? ax + a.ndim() : ax);
}
if (a.shape(ax) != 1) {
std::ostringstream msg;
msg << "[squeeze] Cannot squeeze axis " << ax << " with size "
<< a.shape(ax) << " which is not equal to 1.";
throw std::invalid_argument(msg.str());
}
unique_axes.insert(ax);
}
if (unique_axes.size() != axes.size()) {
throw std::invalid_argument("[squeeze] Received duplicate axes.");
}
std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end());
Shape shape;
for (int i = 0, j = 0; i < a.ndim(); ++i) {
if (j < sorted_axes.size() && i == sorted_axes[j]) {
j++;
} else {
shape.push_back(a.shape(i));
}
}
return reshape(a, std::move(shape), s);
return squeeze_impl(a, std::move(sorted_axes), s);
}
array squeeze(const array& a, int axis, StreamOrDevice s /* = {} */) {
int ax = axis < 0 ? axis + a.ndim() : axis;
if (ax < 0 || ax >= a.ndim()) {
std::ostringstream msg;
msg << "[squeeze] Invalid axis " << axis << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
auto shape = a.shape();
shape.erase(shape.begin() + ax);
return reshape(a, std::move(shape), s);
return squeeze_impl(a, {axis}, s);
}
array squeeze(const array& a, StreamOrDevice s /* = {} */) {
@ -517,21 +512,34 @@ array squeeze(const array& a, StreamOrDevice s /* = {} */) {
axes.push_back(i);
}
}
return squeeze(a, axes, s);
return squeeze_impl(a, std::move(axes), s);
}
array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) {
int out_dim = a.ndim() + 1;
int ax = axis < 0 ? axis + out_dim : axis;
if (ax < 0 || ax >= out_dim) {
array expand_dims_impl(
const array& a,
std::vector<int> axes,
StreamOrDevice s /* = {} */) {
auto out_ndim = a.ndim() + axes.size();
for (auto& ax : axes) {
auto new_ax = ax < 0 ? ax + out_ndim : ax;
if (new_ax < 0 || new_ax >= out_ndim) {
std::ostringstream msg;
msg << "[expand_dims] Invalid axis " << axis << " for output array with "
msg << "[expand_dims] Invalid axis " << ax << " for output array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
auto shape = a.shape();
shape.insert(shape.begin() + ax, 1);
return reshape(a, std::move(shape), s);
ax = new_ax;
}
auto shape = ExpandDims::output_shape(a, axes);
return array(
std::move(shape),
a.dtype(),
std::make_shared<ExpandDims>(to_stream(s), std::move(axes)),
{a});
}
array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) {
return expand_dims_impl(a, {axis}, s);
}
array expand_dims(
@ -544,31 +552,17 @@ array expand_dims(
throw std::invalid_argument("[expand_dims] Received duplicate axes.");
}
}
int out_ndim = axes.size() + a.ndim();
std::vector<int> canonical_axes = axes;
for (auto& ax : canonical_axes) {
ax = ax < 0 ? ax + out_ndim : ax;
if (ax < 0 || ax >= out_ndim) {
std::ostringstream msg;
msg << "[expand_dims] Invalid axis " << ax << " for output array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
}
// Check for repeats again
std::set<int> unique_axes(canonical_axes.begin(), canonical_axes.end());
auto out_ndim = a.ndim() + axes.size();
std::set<int> unique_axes;
for (auto ax : axes) {
unique_axes.insert(ax < 0 ? ax + out_ndim : ax);
}
if (unique_axes.size() != axes.size()) {
throw std::invalid_argument("[expand_dims] Received duplicate axes.");
}
std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end());
auto out_shape = a.shape();
for (int i = 0; i < sorted_axes.size(); ++i) {
out_shape.insert(out_shape.begin() + sorted_axes[i], 1);
}
return reshape(a, std::move(out_shape), s);
return expand_dims_impl(a, std::move(sorted_axes), s);
}
// Slice helper
@ -1519,7 +1513,7 @@ array all(
const std::vector<int>& axes,
bool keepdims /* = false */,
StreamOrDevice s /* = {}*/) {
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? astype(a, bool_, s)
@ -1529,7 +1523,7 @@ array all(
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
{a});
if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s);
out = squeeze(out, sorted_axes, s);
}
return out;
}
@ -1553,7 +1547,7 @@ array any(
const std::vector<int>& axes,
bool keepdims /* = false */,
StreamOrDevice s /* = {}*/) {
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? astype(a, bool_, s)
@ -1563,7 +1557,7 @@ array any(
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
{a});
if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s);
out = squeeze(out, sorted_axes, s);
}
return out;
}
@ -1590,7 +1584,7 @@ array sum(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape());
Dtype out_type = a.dtype();
if (issubdtype(a.dtype(), signedinteger)) {
@ -1608,7 +1602,7 @@ array sum(
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
{a});
if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s);
out = squeeze(out, sorted_axes, s);
}
return out;
}
@ -1742,7 +1736,7 @@ array prod(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape());
Dtype out_type = a.dtype();
if (issubdtype(a.dtype(), signedinteger)) {
@ -1760,7 +1754,7 @@ array prod(
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
{a});
if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s);
out = squeeze(out, sorted_axes, s);
}
return out;
}
@ -1787,7 +1781,7 @@ array max(
if (a.size() == 0) {
throw std::invalid_argument("[max] Cannot max reduce zero size array.");
}
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? a
@ -1797,7 +1791,7 @@ array max(
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
{a});
if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s);
out = squeeze(out, sorted_axes, s);
}
return out;
}
@ -1827,7 +1821,7 @@ array min(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? a
@ -1837,7 +1831,7 @@ array min(
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
{a});
if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s);
out = squeeze(out, sorted_axes, s);
}
return out;
}
@ -1870,7 +1864,7 @@ array argmin(
throw std::invalid_argument(
"[argmin] Cannot argmin reduce zero size array.");
}
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape({axis}, a.shape());
auto out = (is_noop)
? zeros(out_shape, uint32, s)
@ -1881,7 +1875,7 @@ array argmin(
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
{a});
if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s);
out = squeeze(out, sorted_axes[0], s);
}
return out;
}
@ -1906,7 +1900,7 @@ array argmax(
throw std::invalid_argument(
"[argmax] Cannot argmax reduce zero size array.");
}
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape({axis}, a.shape());
auto out = (is_noop)
? zeros(out_shape, uint32, s)
@ -1917,7 +1911,7 @@ array argmax(
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
{a});
if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s);
out = squeeze(out, sorted_axes[0], s);
}
return out;
}
@ -2544,11 +2538,11 @@ array matmul(
}
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = reshape(a, {1, -1}, s);
a = expand_dims(a, 0, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = reshape(b, {-1, 1}, s);
b = expand_dims(b, 1, s);
}
if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg;
@ -2608,17 +2602,21 @@ array matmul(
auto out_shape = a.shape();
out_shape.back() = b.shape(-1);
auto p = std::make_shared<Matmul>(to_stream(s));
auto out = array(
std::move(out_shape),
out_type,
std::make_shared<Matmul>(to_stream(s)),
{a, b});
// Remove the possibly inserted singleton dimensions
if (in_a.ndim() == 1 || in_b.ndim() == 1) {
auto out = array(out_shape, out_type, std::move(p), {a, b});
out_shape.erase(
out_shape.end() - ((in_a.ndim() == 1) ? 2 : 1),
out_shape.end() - ((in_b.ndim() == 1) ? 0 : 1));
return reshape(out, std::move(out_shape), s);
std::vector<int> axes;
if (in_a.ndim() == 1) {
axes.push_back(out.ndim() - 2);
}
return array(std::move(out_shape), out_type, std::move(p), {a, b});
if (in_b.ndim() == 1) {
axes.push_back(out.ndim() - 1);
}
return axes.empty() ? out : squeeze(out, axes, s);
}
array gather(
@ -2658,15 +2656,6 @@ array gather(
<< " for array with " << a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
for (int i = 0; i < a.ndim(); ++i) {
if (slice_sizes[i] < 0 || slice_sizes[i] > a.shape(i)) {
std::ostringstream msg;
msg << "[gather] Slice sizes must be in [0, a.shape(i)]. Got "
<< slice_sizes << " for array with shape " << a.shape() << ".";
throw std::invalid_argument(msg.str());
}
}
// Promote indices to the same type
auto dtype = result_type(indices);
if (issubdtype(dtype, inexact)) {
@ -2680,6 +2669,29 @@ array gather(
idx = astype(idx, dtype, s);
}
if (a.size() == 0) {
// Empty input, either the total slice size is 0 or the indices are empty
auto total_slice = std::accumulate(
slice_sizes.begin(), slice_sizes.end(), 1, std::multiplies<int64_t>{});
auto idx_size = !inputs.empty() ? inputs[0].size() : 1;
if (idx_size != 0 && total_slice != 0) {
std::ostringstream msg;
msg << "[gather] If the input is empty, either the indices must be"
<< " empty or the total slice size must be 0.";
throw std::invalid_argument(msg.str());
}
} else {
// Non-empty input, check slice sizes are valid
for (int i = 0; i < a.ndim(); ++i) {
if (slice_sizes[i] < 0 || slice_sizes[i] > a.shape(i)) {
std::ostringstream msg;
msg << "[gather] Slice sizes must be in [0, a.shape(i)]. Got "
<< slice_sizes << " for array with shape " << a.shape() << ".";
throw std::invalid_argument(msg.str());
}
}
}
Shape out_shape;
if (!inputs.empty()) {
out_shape = inputs[0].shape();
@ -2688,9 +2700,10 @@ array gather(
inputs.insert(inputs.begin(), a);
return array(
out_shape,
std::move(out_shape),
a.dtype(),
std::make_shared<Gather>(to_stream(s), axes, slice_sizes),
std::make_shared<Gather>(
to_stream(s), std::move(axes), std::move(slice_sizes)),
inputs);
}
@ -2719,7 +2732,7 @@ array take(
// Make slice sizes to pass to gather
Shape slice_sizes = a.shape();
slice_sizes[axis] = indices.size() > 0 ? 1 : 0;
slice_sizes[axis] = 1;
auto out = gather(a, indices, axis, slice_sizes, s);
@ -2736,9 +2749,7 @@ array take(
}
// Squeeze the axis we take over
auto out_shape = out.shape();
out_shape.erase(out_shape.begin() + indices.ndim() + axis);
return reshape(out, std::move(out_shape), s);
return squeeze(out, indices.ndim() + axis, s);
}
array take(const array& a, const array& indices, StreamOrDevice s /* = {} */) {
@ -2811,12 +2822,14 @@ array take_along_axis(
}
std::vector<int> dims(a.ndim());
std::iota(dims.begin(), dims.end(), 0);
Shape slice_sizes(a.ndim(), a.size() > 0);
Shape slice_sizes(a.ndim(), 1);
auto out = gather(a, nd_indices, dims, slice_sizes, s);
// Squeeze out the slice shape
Shape out_shape(out.shape().begin(), out.shape().begin() + a.ndim());
return reshape(out, std::move(out_shape), s);
for (auto& d : dims) {
d += a.ndim();
}
return squeeze(out, dims, s);
}
array put_along_axis(
@ -3935,17 +3948,20 @@ array addmm(
}
auto out = array(
out_shape,
std::move(out_shape),
out_type,
std::make_shared<AddMM>(to_stream(s), alpha, beta),
{a, b, c});
// Remove the possibly inserted singleton dimensions
if (in_a_ndim == 1 || in_b_ndim == 1) {
out = reshape(out, out_shape_adjusted, s);
std::vector<int> axes;
if (in_a_ndim == 1) {
axes.push_back(out.ndim() - 2);
}
return out;
if (in_b_ndim == 1) {
axes.push_back(out.ndim() - 1);
}
return axes.empty() ? out : squeeze(out, axes, s);
}
/** Compute matrix product with tile-level masking */
@ -3986,11 +4002,11 @@ array block_masked_mm(
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = reshape(a, {1, -1}, s);
a = expand_dims(a, 0, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = reshape(b, {-1, 1}, s);
b = expand_dims(b, 1, s);
}
if (a.shape(-1) != b.shape(-2)) {
@ -4110,20 +4126,19 @@ array block_masked_mm(
// Caculate array
auto out = array(
out_shape,
std::move(out_shape),
out_type,
std::make_shared<BlockMaskedMM>(to_stream(s), block_size),
std::move(inputs));
// Remove the possibly inserted singleton dimensions
if (in_a_ndim == 1 || in_b_ndim == 1) {
out_shape.erase(
out_shape.end() - ((in_a_ndim == 1) ? 2 : 1),
out_shape.end() - ((in_b_ndim == 1) ? 0 : 1));
out = reshape(out, out_shape, s);
std::vector<int> axes;
if (in_a_ndim == 1) {
axes.push_back(out.ndim() - 2);
}
return out;
if (in_b_ndim == 1) {
axes.push_back(out.ndim() - 1);
}
return axes.empty() ? out : squeeze(out, axes, s);
}
/** Compute matrix product with matrix-level gather */
@ -4150,11 +4165,11 @@ array gather_mm(
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = reshape(a, {1, -1}, s);
a = expand_dims(a, 0, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = reshape(b, {-1, 1}, s);
b = expand_dims(b, 1, s);
}
if (a.shape(-1) != b.shape(-2)) {
@ -4212,20 +4227,20 @@ array gather_mm(
// Caculate array
auto out = array(
out_shape,
std::move(out_shape),
out_type,
std::make_shared<GatherMM>(to_stream(s)),
{a, b, lhs_indices, rhs_indices});
// Remove the possibly inserted singleton dimensions
if (in_a_ndim == 1 || in_b_ndim == 1) {
out_shape.erase(
out_shape.end() - ((in_a_ndim == 1) ? 2 : 1),
out_shape.end() - ((in_b_ndim == 1) ? 0 : 1));
out = reshape(out, out_shape, s);
std::vector<int> axes;
if (in_a_ndim == 1) {
axes.push_back(out.ndim() - 2);
}
return out;
if (in_b_ndim == 1) {
axes.push_back(out.ndim() - 1);
}
return axes.empty() ? out : squeeze(out, axes, s);
}
array diagonal(

View File

@ -1602,6 +1602,55 @@ std::pair<std::vector<array>, std::vector<int>> Expm1::vmap(
return {{expm1(inputs[0], stream())}, axes};
}
std::vector<array> ExpandDims::vjp(
const std::vector<array>&,
const std::vector<array>& cotangents,
const std::vector<int>&,
const std::vector<array>&) {
return {squeeze(cotangents[0], axes_, stream())};
}
std::vector<array> ExpandDims::jvp(
const std::vector<array>&,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {expand_dims(tangents[0], axes_, stream())};
}
std::pair<std::vector<array>, std::vector<int>> ExpandDims::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0];
auto expand_axes = axes_;
for (auto& s : expand_axes) {
if (s >= axes[0]) {
s++;
} else {
ax++;
}
}
return {{expand_dims(inputs[0], std::move(expand_axes), stream())}, {ax}};
}
bool ExpandDims::is_equivalent(const Primitive& other) const {
const ExpandDims& a_other = static_cast<const ExpandDims&>(other);
return (axes_ == a_other.axes_);
}
Shape ExpandDims::output_shape(
const array& input,
const std::vector<int>& axes) {
auto shape = input.shape();
for (auto ax : axes) {
shape.insert(shape.begin() + ax, 1);
}
return shape;
}
std::vector<Shape> ExpandDims::output_shapes(const std::vector<array>& inputs) {
return {ExpandDims::output_shape(inputs[0], axes_)};
}
bool FFT::is_equivalent(const Primitive& other) const {
const FFT& r_other = static_cast<const FFT&>(other);
return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&
@ -1846,7 +1895,7 @@ bool Gather::is_equivalent(const Primitive& other) const {
std::vector<Shape> Gather::output_shapes(const std::vector<array>& inputs) {
Shape out_shape;
if (inputs.size() > 1) {
out_shape = inputs[0].shape();
out_shape = inputs[1].shape();
}
out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end());
return {std::move(out_shape)};
@ -3847,6 +3896,57 @@ std::pair<std::vector<array>, std::vector<int>> Subtract::vmap(
return {{subtract(a, b, stream())}, {to_ax}};
}
std::vector<array> Squeeze::vjp(
const std::vector<array>&,
const std::vector<array>& cotangents,
const std::vector<int>&,
const std::vector<array>&) {
return {expand_dims(cotangents[0], axes_, stream())};
}
std::vector<array> Squeeze::jvp(
const std::vector<array>&,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {squeeze(tangents[0], axes_, stream())};
}
std::pair<std::vector<array>, std::vector<int>> Squeeze::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0];
auto squeeze_axes = axes_;
for (auto& s : squeeze_axes) {
if (s >= axes[0]) {
s++;
} else {
ax--;
}
}
return {{squeeze(inputs[0], std::move(squeeze_axes), stream())}, {ax}};
}
bool Squeeze::is_equivalent(const Primitive& other) const {
const Squeeze& a_other = static_cast<const Squeeze&>(other);
return (axes_ == a_other.axes_);
}
Shape Squeeze::output_shape(const array& input, const std::vector<int>& axes) {
Shape shape;
for (int i = 0, j = 0; i < input.ndim(); ++i) {
if (j < axes.size() && i == axes[j]) {
j++;
} else {
shape.push_back(input.shape(i));
}
}
return shape;
}
std::vector<Shape> Squeeze::output_shapes(const std::vector<array>& inputs) {
return {Squeeze::output_shape(inputs[0], axes_)};
}
std::vector<array> Tan::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@ -983,6 +983,28 @@ class Expm1 : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class ExpandDims : public UnaryPrimitive {
public:
explicit ExpandDims(Stream stream, std::vector<int> axes)
: UnaryPrimitive(stream), axes_(std::move(axes)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(ExpandDims)
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, const std::vector<int>& axes);
private:
void eval(const std::vector<array>& inputs, array& out);
std::vector<int> axes_;
};
class FFT : public UnaryPrimitive {
public:
explicit FFT(
@ -1046,9 +1068,11 @@ class Gather : public UnaryPrimitive {
public:
explicit Gather(
Stream stream,
const std::vector<int>& axes,
const std::vector<int>& slice_sizes)
: UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {}
std::vector<int> axes,
std::vector<int> slice_sizes)
: UnaryPrimitive(stream),
axes_(std::move(axes)),
slice_sizes_(std::move(slice_sizes)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -2057,6 +2081,28 @@ class Subtract : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Squeeze : public UnaryPrimitive {
public:
explicit Squeeze(Stream stream, std::vector<int> axes)
: UnaryPrimitive(stream), axes_(std::move(axes)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Squeeze)
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, const std::vector<int>& axes);
private:
void eval(const std::vector<array>& inputs, array& out);
std::vector<int> axes_;
};
class Tan : public UnaryPrimitive {
public:
explicit Tan(Stream stream) : UnaryPrimitive(stream) {}

View File

@ -144,23 +144,23 @@ array mlx_gather_nd(
int slice_index = 0;
for (int i = 0; i < gather_indices.size(); i++) {
if (is_slice[i]) {
std::vector<int> index_shape(max_dims + num_slices, 1);
Shape index_shape(max_dims + num_slices, 1);
index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
gather_indices[i] = reshape(gather_indices[i], index_shape);
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
slice_index++;
} else {
std::vector<int> index_shape = gather_indices[i].shape();
auto index_shape = gather_indices[i].shape();
index_shape.insert(index_shape.end(), num_slices, 1);
gather_indices[i] = reshape(gather_indices[i], index_shape);
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
}
}
} else {
// reshape them so that the int/array indices are last
for (int i = 0; i < gather_indices.size(); i++) {
if (i < num_slices) {
std::vector<int> index_shape(max_dims + num_slices, 1);
Shape index_shape(max_dims + num_slices, 1);
index_shape[i] = gather_indices[i].shape(0);
gather_indices[i] = reshape(gather_indices[i], index_shape);
gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
}
}
}
@ -172,19 +172,11 @@ array mlx_gather_nd(
std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
src = gather(src, gather_indices, axes, slice_sizes);
// Squeeze the dims
std::vector<int> out_shape;
out_shape.insert(
out_shape.end(),
src.shape().begin(),
src.shape().begin() + max_dims + num_slices);
out_shape.insert(
out_shape.end(),
src.shape().begin() + max_dims + num_slices + indices.size(),
src.shape().end());
src = reshape(src, out_shape);
return src;
// Squeeze the array index dims
for (auto& ax : axes) {
ax += max_dims + num_slices;
}
return squeeze(src, axes);
}
auto mlx_expand_ellipsis(

View File

@ -392,27 +392,6 @@ class TestCompile(mlx_tests.MLXTestCase):
out = fun(x, y=y, z=z)
self.assertEqual(out.item(), 6)
def test_shapeless_compile(self):
y = 1
@partial(mx.compile, shapeless=True)
def fun(x):
return x + y
x = mx.array([1, 2])
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3])))
# The function is not recompiled, so the change
# to y should not be reflected in the output
y = 2
x = mx.array([1, 2, 3])
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4])))
# Type change recompiles
x = mx.array([1.0, 2.0, 3.0])
self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0])))
fun(x, y=y, z=z)
def test_shapeless_compile(self):
y = 1
@ -477,6 +456,12 @@ class TestCompile(mlx_tests.MLXTestCase):
mx.eval(cfun(x1))
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
def fun(x):
return x * x.sum(-1, keepdims=False)
cfun = mx.compile(fun, shapeless=True)
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
def test_compile_with_constant(self):
# Test float
@partial(mx.compile)
@ -809,6 +794,13 @@ class TestCompile(mlx_tests.MLXTestCase):
out = fun(*inputs)
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
def test_shapeless_compile_matmul(self):
a = mx.array([0.0, 1.0, 2.0])
b = mx.array([0.0, 1.0, 2.0])
fun = mx.compile(lambda a, b: a @ b, shapeless=True)
self.assertTrue(mx.allclose(fun(a, b), a @ b))
if __name__ == "__main__":
unittest.main()

View File

@ -1835,6 +1835,9 @@ TEST_CASE("test broadcast") {
}
TEST_CASE("test gather") {
// Empty input, non-empty indices/slice
CHECK_THROWS(gather(array({}), array({1}), 0, {1}));
// More indices than dimensions
CHECK_THROWS(gather(array(0), array({1}), 0, {1}));