ExpandDims primitive (#1687)

* add squeeze primitive

* simplify squeeze, use in gather

* fix

* fix

* fix

* fix

* fix no cpu

* use squeeze in matmul and friends

* expand dims primitive

* comment
This commit is contained in:
Awni Hannun 2024-12-10 16:39:07 -08:00 committed by GitHub
parent 310ad8d9db
commit f76a49e555
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 373 additions and 184 deletions

View File

@ -43,6 +43,7 @@ DEFAULT(NumberOfElements)
DEFAULT(Equal) DEFAULT(Equal)
DEFAULT(Erf) DEFAULT(Erf)
DEFAULT(ErfInv) DEFAULT(ErfInv)
DEFAULT(ExpandDims)
DEFAULT(FFT) DEFAULT(FFT)
DEFAULT(Floor) DEFAULT(Floor)
DEFAULT(Gather) DEFAULT(Gather)
@ -76,6 +77,7 @@ DEFAULT(Slice)
DEFAULT(SliceUpdate) DEFAULT(SliceUpdate)
DEFAULT_MULTI(Split) DEFAULT_MULTI(Split)
DEFAULT(Sort) DEFAULT(Sort)
DEFAULT(Squeeze)
DEFAULT(StopGradient) DEFAULT(StopGradient)
DEFAULT_MULTI(SVD) DEFAULT_MULTI(SVD)
DEFAULT(Transpose) DEFAULT(Transpose)

View File

@ -85,6 +85,16 @@ void Depends::eval(
} }
} }
void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto strides = in.strides();
for (auto ax : axes_) {
strides.insert(strides.begin() + ax, 1);
}
move_or_copy(in, out, strides, in.flags(), in.data_size());
}
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) { void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
@ -248,6 +258,20 @@ void Split::eval(
} }
} }
void Squeeze::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
Strides strides;
for (int i = 0, j = 0; i < in.ndim(); ++i) {
if (j < axes_.size() && i == axes_[j]) {
j++;
} else {
strides.push_back(in.strides(i));
}
}
move_or_copy(in, out, strides, in.flags(), in.data_size());
}
void StopGradient::eval(const std::vector<array>& inputs, array& out) { void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
move_or_copy(inputs[0], out); move_or_copy(inputs[0], out);

View File

@ -57,6 +57,7 @@ DEFAULT(Equal)
DEFAULT(Erf) DEFAULT(Erf)
DEFAULT(ErfInv) DEFAULT(ErfInv)
DEFAULT(Exp) DEFAULT(Exp)
DEFAULT(ExpandDims)
DEFAULT(Expm1) DEFAULT(Expm1)
DEFAULT(FFT) DEFAULT(FFT)
DEFAULT(Floor) DEFAULT(Floor)
@ -101,6 +102,7 @@ DEFAULT(Softmax)
DEFAULT(Sort) DEFAULT(Sort)
DEFAULT_MULTI(Split) DEFAULT_MULTI(Split)
DEFAULT(Square) DEFAULT(Square)
DEFAULT(Squeeze)
DEFAULT(Sqrt) DEFAULT(Sqrt)
DEFAULT(StopGradient) DEFAULT(StopGradient)
DEFAULT(Subtract) DEFAULT(Subtract)

View File

@ -211,6 +211,10 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
copy_gpu(in, out, ctype); copy_gpu(in, out, ctype);
} }
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Load::eval_gpu(const std::vector<array>& inputs, array& out) { void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto read_task = [out = out, auto read_task = [out = out,
@ -381,6 +385,10 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const Stream& s = */ stream()); /* const Stream& s = */ stream());
} }
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) { void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out); eval(inputs, out);
} }

View File

@ -55,6 +55,7 @@ NO_CPU(Equal)
NO_CPU(Erf) NO_CPU(Erf)
NO_CPU(ErfInv) NO_CPU(ErfInv)
NO_CPU(Exp) NO_CPU(Exp)
NO_CPU(ExpandDims)
NO_CPU(Expm1) NO_CPU(Expm1)
NO_CPU(FFT) NO_CPU(FFT)
NO_CPU(Floor) NO_CPU(Floor)
@ -104,6 +105,7 @@ NO_CPU(Softmax)
NO_CPU(Sort) NO_CPU(Sort)
NO_CPU_MULTI(Split) NO_CPU_MULTI(Split)
NO_CPU(Square) NO_CPU(Square)
NO_CPU(Squeeze)
NO_CPU(Sqrt) NO_CPU(Sqrt)
NO_CPU(StopGradient) NO_CPU(StopGradient)
NO_CPU(Subtract) NO_CPU(Subtract)

View File

@ -55,6 +55,7 @@ NO_GPU(Equal)
NO_GPU(Erf) NO_GPU(Erf)
NO_GPU(ErfInv) NO_GPU(ErfInv)
NO_GPU(Exp) NO_GPU(Exp)
NO_GPU(ExpandDims)
NO_GPU(Expm1) NO_GPU(Expm1)
NO_GPU(FFT) NO_GPU(FFT)
NO_GPU(Floor) NO_GPU(Floor)
@ -104,6 +105,7 @@ NO_GPU(Softmax)
NO_GPU(Sort) NO_GPU(Sort)
NO_GPU_MULTI(Split) NO_GPU_MULTI(Split)
NO_GPU(Square) NO_GPU(Square)
NO_GPU(Squeeze)
NO_GPU(Sqrt) NO_GPU(Sqrt)
NO_GPU(StopGradient) NO_GPU(StopGradient)
NO_GPU(Subtract) NO_GPU(Subtract)

View File

@ -81,6 +81,7 @@ bool allows_shapeless(const Primitive& p) {
typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) || typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) ||
typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) || typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) || typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) ||
typeid(p) == typeid(Squeeze) || typeid(p) == typeid(ExpandDims) ||
typeid(p) == typeid(fast::AffineQuantize) || typeid(p) == typeid(fast::AffineQuantize) ||
typeid(p) == typeid(fast::LayerNorm) || typeid(p) == typeid(fast::LayerNorm) ||
typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) || typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||

View File

@ -20,7 +20,7 @@ namespace mlx::core {
namespace { namespace {
std::tuple<Shape, std::vector<int>, Shape, bool> compute_reduce_shape( std::tuple<Shape, std::vector<int>, bool> compute_reduce_shape(
const std::vector<int>& axes, const std::vector<int>& axes,
const Shape& shape) { const Shape& shape) {
bool is_noop = true; bool is_noop = true;
@ -40,18 +40,16 @@ std::tuple<Shape, std::vector<int>, Shape, bool> compute_reduce_shape(
throw std::invalid_argument("Duplicate axes detected in reduction."); throw std::invalid_argument("Duplicate axes detected in reduction.");
} }
Shape out_shape; Shape out_shape;
Shape squeezed_shape;
for (int i = 0; i < ndim; ++i) { for (int i = 0; i < ndim; ++i) {
if (axes_set.count(i) == 0) { if (axes_set.count(i) == 0) {
out_shape.push_back(shape[i]); out_shape.push_back(shape[i]);
squeezed_shape.push_back(shape[i]);
} else { } else {
out_shape.push_back(1); out_shape.push_back(1);
} }
is_noop &= (out_shape.back() == shape[i]); is_noop &= (out_shape.back() == shape[i]);
} }
std::vector<int> sorted_axes(axes_set.begin(), axes_set.end()); std::vector<int> sorted_axes(axes_set.begin(), axes_set.end());
return {out_shape, sorted_axes, squeezed_shape, is_noop}; return {out_shape, sorted_axes, is_noop};
} }
Dtype at_least_float(const Dtype& d) { Dtype at_least_float(const Dtype& d) {
@ -460,54 +458,51 @@ array hadamard_transform(
{astype(a, dtype, s)}); {astype(a, dtype, s)});
} }
array squeeze_impl(
const array& a,
std::vector<int> axes,
StreamOrDevice s /* = {} */) {
for (auto& ax : axes) {
auto new_ax = ax < 0 ? ax + a.ndim() : ax;
if (new_ax < 0 || new_ax >= a.ndim()) {
std::ostringstream msg;
msg << "[squeeze] Invalid axes " << ax << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
if (a.shape(new_ax) != 1) {
std::ostringstream msg;
msg << "[squeeze] Cannot squeeze axis " << ax << " with size "
<< a.shape(ax) << " which is not equal to 1.";
throw std::invalid_argument(msg.str());
}
ax = new_ax;
}
auto shape = Squeeze::output_shape(a, axes);
return array(
std::move(shape),
a.dtype(),
std::make_shared<Squeeze>(to_stream(s), std::move(axes)),
{a});
}
array squeeze( array squeeze(
const array& a, const array& a,
const std::vector<int>& axes, const std::vector<int>& axes,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
std::set<int> unique_axes; std::set<int> unique_axes;
for (auto ax : axes) { for (auto ax : axes) {
ax = ax < 0 ? ax + a.ndim() : ax; unique_axes.insert(ax < 0 ? ax + a.ndim() : ax);
if (ax < 0 || ax >= a.ndim()) {
std::ostringstream msg;
msg << "[squeeze] Invalid axes " << ax << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
if (a.shape(ax) != 1) {
std::ostringstream msg;
msg << "[squeeze] Cannot squeeze axis " << ax << " with size "
<< a.shape(ax) << " which is not equal to 1.";
throw std::invalid_argument(msg.str());
}
unique_axes.insert(ax);
} }
if (unique_axes.size() != axes.size()) { if (unique_axes.size() != axes.size()) {
throw std::invalid_argument("[squeeze] Received duplicate axes."); throw std::invalid_argument("[squeeze] Received duplicate axes.");
} }
std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end()); std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end());
Shape shape; return squeeze_impl(a, std::move(sorted_axes), s);
for (int i = 0, j = 0; i < a.ndim(); ++i) {
if (j < sorted_axes.size() && i == sorted_axes[j]) {
j++;
} else {
shape.push_back(a.shape(i));
}
}
return reshape(a, std::move(shape), s);
} }
array squeeze(const array& a, int axis, StreamOrDevice s /* = {} */) { array squeeze(const array& a, int axis, StreamOrDevice s /* = {} */) {
int ax = axis < 0 ? axis + a.ndim() : axis; return squeeze_impl(a, {axis}, s);
if (ax < 0 || ax >= a.ndim()) {
std::ostringstream msg;
msg << "[squeeze] Invalid axis " << axis << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
auto shape = a.shape();
shape.erase(shape.begin() + ax);
return reshape(a, std::move(shape), s);
} }
array squeeze(const array& a, StreamOrDevice s /* = {} */) { array squeeze(const array& a, StreamOrDevice s /* = {} */) {
@ -517,21 +512,34 @@ array squeeze(const array& a, StreamOrDevice s /* = {} */) {
axes.push_back(i); axes.push_back(i);
} }
} }
return squeeze(a, axes, s); return squeeze_impl(a, std::move(axes), s);
}
array expand_dims_impl(
const array& a,
std::vector<int> axes,
StreamOrDevice s /* = {} */) {
auto out_ndim = a.ndim() + axes.size();
for (auto& ax : axes) {
auto new_ax = ax < 0 ? ax + out_ndim : ax;
if (new_ax < 0 || new_ax >= out_ndim) {
std::ostringstream msg;
msg << "[expand_dims] Invalid axis " << ax << " for output array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
ax = new_ax;
}
auto shape = ExpandDims::output_shape(a, axes);
return array(
std::move(shape),
a.dtype(),
std::make_shared<ExpandDims>(to_stream(s), std::move(axes)),
{a});
} }
array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) { array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) {
int out_dim = a.ndim() + 1; return expand_dims_impl(a, {axis}, s);
int ax = axis < 0 ? axis + out_dim : axis;
if (ax < 0 || ax >= out_dim) {
std::ostringstream msg;
msg << "[expand_dims] Invalid axis " << axis << " for output array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
auto shape = a.shape();
shape.insert(shape.begin() + ax, 1);
return reshape(a, std::move(shape), s);
} }
array expand_dims( array expand_dims(
@ -544,31 +552,17 @@ array expand_dims(
throw std::invalid_argument("[expand_dims] Received duplicate axes."); throw std::invalid_argument("[expand_dims] Received duplicate axes.");
} }
} }
int out_ndim = axes.size() + a.ndim();
std::vector<int> canonical_axes = axes;
for (auto& ax : canonical_axes) {
ax = ax < 0 ? ax + out_ndim : ax;
if (ax < 0 || ax >= out_ndim) {
std::ostringstream msg;
msg << "[expand_dims] Invalid axis " << ax << " for output array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
}
// Check for repeats again // Check for repeats again
std::set<int> unique_axes(canonical_axes.begin(), canonical_axes.end()); auto out_ndim = a.ndim() + axes.size();
std::set<int> unique_axes;
for (auto ax : axes) {
unique_axes.insert(ax < 0 ? ax + out_ndim : ax);
}
if (unique_axes.size() != axes.size()) { if (unique_axes.size() != axes.size()) {
throw std::invalid_argument("[expand_dims] Received duplicate axes."); throw std::invalid_argument("[expand_dims] Received duplicate axes.");
} }
std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end()); std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end());
auto out_shape = a.shape(); return expand_dims_impl(a, std::move(sorted_axes), s);
for (int i = 0; i < sorted_axes.size(); ++i) {
out_shape.insert(out_shape.begin() + sorted_axes[i], 1);
}
return reshape(a, std::move(out_shape), s);
} }
// Slice helper // Slice helper
@ -1519,7 +1513,7 @@ array all(
const std::vector<int>& axes, const std::vector<int>& axes,
bool keepdims /* = false */, bool keepdims /* = false */,
StreamOrDevice s /* = {}*/) { StreamOrDevice s /* = {}*/) {
auto [out_shape, sorted_axes, squeezed_shape, is_noop] = auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape()); compute_reduce_shape(axes, a.shape());
auto out = (is_noop) auto out = (is_noop)
? astype(a, bool_, s) ? astype(a, bool_, s)
@ -1529,7 +1523,7 @@ array all(
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes), std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
{a}); {a});
if (!keepdims) { if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s); out = squeeze(out, sorted_axes, s);
} }
return out; return out;
} }
@ -1553,7 +1547,7 @@ array any(
const std::vector<int>& axes, const std::vector<int>& axes,
bool keepdims /* = false */, bool keepdims /* = false */,
StreamOrDevice s /* = {}*/) { StreamOrDevice s /* = {}*/) {
auto [out_shape, sorted_axes, squeezed_shape, is_noop] = auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape()); compute_reduce_shape(axes, a.shape());
auto out = (is_noop) auto out = (is_noop)
? astype(a, bool_, s) ? astype(a, bool_, s)
@ -1563,7 +1557,7 @@ array any(
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes), std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
{a}); {a});
if (!keepdims) { if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s); out = squeeze(out, sorted_axes, s);
} }
return out; return out;
} }
@ -1590,7 +1584,7 @@ array sum(
if (axes.empty()) { if (axes.empty()) {
return a; return a;
} }
auto [out_shape, sorted_axes, squeezed_shape, is_noop] = auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape()); compute_reduce_shape(axes, a.shape());
Dtype out_type = a.dtype(); Dtype out_type = a.dtype();
if (issubdtype(a.dtype(), signedinteger)) { if (issubdtype(a.dtype(), signedinteger)) {
@ -1608,7 +1602,7 @@ array sum(
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes), std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
{a}); {a});
if (!keepdims) { if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s); out = squeeze(out, sorted_axes, s);
} }
return out; return out;
} }
@ -1742,7 +1736,7 @@ array prod(
if (axes.empty()) { if (axes.empty()) {
return a; return a;
} }
auto [out_shape, sorted_axes, squeezed_shape, is_noop] = auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape()); compute_reduce_shape(axes, a.shape());
Dtype out_type = a.dtype(); Dtype out_type = a.dtype();
if (issubdtype(a.dtype(), signedinteger)) { if (issubdtype(a.dtype(), signedinteger)) {
@ -1760,7 +1754,7 @@ array prod(
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes), std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
{a}); {a});
if (!keepdims) { if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s); out = squeeze(out, sorted_axes, s);
} }
return out; return out;
} }
@ -1787,7 +1781,7 @@ array max(
if (a.size() == 0) { if (a.size() == 0) {
throw std::invalid_argument("[max] Cannot max reduce zero size array."); throw std::invalid_argument("[max] Cannot max reduce zero size array.");
} }
auto [out_shape, sorted_axes, squeezed_shape, is_noop] = auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape()); compute_reduce_shape(axes, a.shape());
auto out = (is_noop) auto out = (is_noop)
? a ? a
@ -1797,7 +1791,7 @@ array max(
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes), std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
{a}); {a});
if (!keepdims) { if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s); out = squeeze(out, sorted_axes, s);
} }
return out; return out;
} }
@ -1827,7 +1821,7 @@ array min(
if (axes.empty()) { if (axes.empty()) {
return a; return a;
} }
auto [out_shape, sorted_axes, squeezed_shape, is_noop] = auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape()); compute_reduce_shape(axes, a.shape());
auto out = (is_noop) auto out = (is_noop)
? a ? a
@ -1837,7 +1831,7 @@ array min(
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes), std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
{a}); {a});
if (!keepdims) { if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s); out = squeeze(out, sorted_axes, s);
} }
return out; return out;
} }
@ -1870,7 +1864,7 @@ array argmin(
throw std::invalid_argument( throw std::invalid_argument(
"[argmin] Cannot argmin reduce zero size array."); "[argmin] Cannot argmin reduce zero size array.");
} }
auto [out_shape, sorted_axes, squeezed_shape, is_noop] = auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape({axis}, a.shape()); compute_reduce_shape({axis}, a.shape());
auto out = (is_noop) auto out = (is_noop)
? zeros(out_shape, uint32, s) ? zeros(out_shape, uint32, s)
@ -1881,7 +1875,7 @@ array argmin(
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]), to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
{a}); {a});
if (!keepdims) { if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s); out = squeeze(out, sorted_axes[0], s);
} }
return out; return out;
} }
@ -1906,7 +1900,7 @@ array argmax(
throw std::invalid_argument( throw std::invalid_argument(
"[argmax] Cannot argmax reduce zero size array."); "[argmax] Cannot argmax reduce zero size array.");
} }
auto [out_shape, sorted_axes, squeezed_shape, is_noop] = auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape({axis}, a.shape()); compute_reduce_shape({axis}, a.shape());
auto out = (is_noop) auto out = (is_noop)
? zeros(out_shape, uint32, s) ? zeros(out_shape, uint32, s)
@ -1917,7 +1911,7 @@ array argmax(
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]), to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
{a}); {a});
if (!keepdims) { if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s); out = squeeze(out, sorted_axes[0], s);
} }
return out; return out;
} }
@ -2544,11 +2538,11 @@ array matmul(
} }
if (a.ndim() == 1) { if (a.ndim() == 1) {
// Insert a singleton dim in the beginning // Insert a singleton dim in the beginning
a = reshape(a, {1, -1}, s); a = expand_dims(a, 0, s);
} }
if (b.ndim() == 1) { if (b.ndim() == 1) {
// Insert a singleton dim at the end // Insert a singleton dim at the end
b = reshape(b, {-1, 1}, s); b = expand_dims(b, 1, s);
} }
if (a.shape(-1) != b.shape(-2)) { if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg; std::ostringstream msg;
@ -2608,17 +2602,21 @@ array matmul(
auto out_shape = a.shape(); auto out_shape = a.shape();
out_shape.back() = b.shape(-1); out_shape.back() = b.shape(-1);
auto p = std::make_shared<Matmul>(to_stream(s)); auto out = array(
std::move(out_shape),
out_type,
std::make_shared<Matmul>(to_stream(s)),
{a, b});
// Remove the possibly inserted singleton dimensions // Remove the possibly inserted singleton dimensions
if (in_a.ndim() == 1 || in_b.ndim() == 1) { std::vector<int> axes;
auto out = array(out_shape, out_type, std::move(p), {a, b}); if (in_a.ndim() == 1) {
out_shape.erase( axes.push_back(out.ndim() - 2);
out_shape.end() - ((in_a.ndim() == 1) ? 2 : 1),
out_shape.end() - ((in_b.ndim() == 1) ? 0 : 1));
return reshape(out, std::move(out_shape), s);
} }
return array(std::move(out_shape), out_type, std::move(p), {a, b}); if (in_b.ndim() == 1) {
axes.push_back(out.ndim() - 1);
}
return axes.empty() ? out : squeeze(out, axes, s);
} }
array gather( array gather(
@ -2658,15 +2656,6 @@ array gather(
<< " for array with " << a.ndim() << " dimensions."; << " for array with " << a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
for (int i = 0; i < a.ndim(); ++i) {
if (slice_sizes[i] < 0 || slice_sizes[i] > a.shape(i)) {
std::ostringstream msg;
msg << "[gather] Slice sizes must be in [0, a.shape(i)]. Got "
<< slice_sizes << " for array with shape " << a.shape() << ".";
throw std::invalid_argument(msg.str());
}
}
// Promote indices to the same type // Promote indices to the same type
auto dtype = result_type(indices); auto dtype = result_type(indices);
if (issubdtype(dtype, inexact)) { if (issubdtype(dtype, inexact)) {
@ -2680,6 +2669,29 @@ array gather(
idx = astype(idx, dtype, s); idx = astype(idx, dtype, s);
} }
if (a.size() == 0) {
// Empty input, either the total slice size is 0 or the indices are empty
auto total_slice = std::accumulate(
slice_sizes.begin(), slice_sizes.end(), 1, std::multiplies<int64_t>{});
auto idx_size = !inputs.empty() ? inputs[0].size() : 1;
if (idx_size != 0 && total_slice != 0) {
std::ostringstream msg;
msg << "[gather] If the input is empty, either the indices must be"
<< " empty or the total slice size must be 0.";
throw std::invalid_argument(msg.str());
}
} else {
// Non-empty input, check slice sizes are valid
for (int i = 0; i < a.ndim(); ++i) {
if (slice_sizes[i] < 0 || slice_sizes[i] > a.shape(i)) {
std::ostringstream msg;
msg << "[gather] Slice sizes must be in [0, a.shape(i)]. Got "
<< slice_sizes << " for array with shape " << a.shape() << ".";
throw std::invalid_argument(msg.str());
}
}
}
Shape out_shape; Shape out_shape;
if (!inputs.empty()) { if (!inputs.empty()) {
out_shape = inputs[0].shape(); out_shape = inputs[0].shape();
@ -2688,9 +2700,10 @@ array gather(
inputs.insert(inputs.begin(), a); inputs.insert(inputs.begin(), a);
return array( return array(
out_shape, std::move(out_shape),
a.dtype(), a.dtype(),
std::make_shared<Gather>(to_stream(s), axes, slice_sizes), std::make_shared<Gather>(
to_stream(s), std::move(axes), std::move(slice_sizes)),
inputs); inputs);
} }
@ -2719,7 +2732,7 @@ array take(
// Make slice sizes to pass to gather // Make slice sizes to pass to gather
Shape slice_sizes = a.shape(); Shape slice_sizes = a.shape();
slice_sizes[axis] = indices.size() > 0 ? 1 : 0; slice_sizes[axis] = 1;
auto out = gather(a, indices, axis, slice_sizes, s); auto out = gather(a, indices, axis, slice_sizes, s);
@ -2736,9 +2749,7 @@ array take(
} }
// Squeeze the axis we take over // Squeeze the axis we take over
auto out_shape = out.shape(); return squeeze(out, indices.ndim() + axis, s);
out_shape.erase(out_shape.begin() + indices.ndim() + axis);
return reshape(out, std::move(out_shape), s);
} }
array take(const array& a, const array& indices, StreamOrDevice s /* = {} */) { array take(const array& a, const array& indices, StreamOrDevice s /* = {} */) {
@ -2811,12 +2822,14 @@ array take_along_axis(
} }
std::vector<int> dims(a.ndim()); std::vector<int> dims(a.ndim());
std::iota(dims.begin(), dims.end(), 0); std::iota(dims.begin(), dims.end(), 0);
Shape slice_sizes(a.ndim(), a.size() > 0); Shape slice_sizes(a.ndim(), 1);
auto out = gather(a, nd_indices, dims, slice_sizes, s); auto out = gather(a, nd_indices, dims, slice_sizes, s);
// Squeeze out the slice shape // Squeeze out the slice shape
Shape out_shape(out.shape().begin(), out.shape().begin() + a.ndim()); for (auto& d : dims) {
return reshape(out, std::move(out_shape), s); d += a.ndim();
}
return squeeze(out, dims, s);
} }
array put_along_axis( array put_along_axis(
@ -3935,17 +3948,20 @@ array addmm(
} }
auto out = array( auto out = array(
out_shape, std::move(out_shape),
out_type, out_type,
std::make_shared<AddMM>(to_stream(s), alpha, beta), std::make_shared<AddMM>(to_stream(s), alpha, beta),
{a, b, c}); {a, b, c});
// Remove the possibly inserted singleton dimensions // Remove the possibly inserted singleton dimensions
if (in_a_ndim == 1 || in_b_ndim == 1) { std::vector<int> axes;
out = reshape(out, out_shape_adjusted, s); if (in_a_ndim == 1) {
axes.push_back(out.ndim() - 2);
} }
if (in_b_ndim == 1) {
return out; axes.push_back(out.ndim() - 1);
}
return axes.empty() ? out : squeeze(out, axes, s);
} }
/** Compute matrix product with tile-level masking */ /** Compute matrix product with tile-level masking */
@ -3986,11 +4002,11 @@ array block_masked_mm(
if (a.ndim() == 1) { if (a.ndim() == 1) {
// Insert a singleton dim in the beginning // Insert a singleton dim in the beginning
a = reshape(a, {1, -1}, s); a = expand_dims(a, 0, s);
} }
if (b.ndim() == 1) { if (b.ndim() == 1) {
// Insert a singleton dim at the end // Insert a singleton dim at the end
b = reshape(b, {-1, 1}, s); b = expand_dims(b, 1, s);
} }
if (a.shape(-1) != b.shape(-2)) { if (a.shape(-1) != b.shape(-2)) {
@ -4110,20 +4126,19 @@ array block_masked_mm(
// Caculate array // Caculate array
auto out = array( auto out = array(
out_shape, std::move(out_shape),
out_type, out_type,
std::make_shared<BlockMaskedMM>(to_stream(s), block_size), std::make_shared<BlockMaskedMM>(to_stream(s), block_size),
std::move(inputs)); std::move(inputs));
// Remove the possibly inserted singleton dimensions // Remove the possibly inserted singleton dimensions
if (in_a_ndim == 1 || in_b_ndim == 1) { std::vector<int> axes;
out_shape.erase( if (in_a_ndim == 1) {
out_shape.end() - ((in_a_ndim == 1) ? 2 : 1), axes.push_back(out.ndim() - 2);
out_shape.end() - ((in_b_ndim == 1) ? 0 : 1));
out = reshape(out, out_shape, s);
} }
if (in_b_ndim == 1) {
return out; axes.push_back(out.ndim() - 1);
}
return axes.empty() ? out : squeeze(out, axes, s);
} }
/** Compute matrix product with matrix-level gather */ /** Compute matrix product with matrix-level gather */
@ -4150,11 +4165,11 @@ array gather_mm(
if (a.ndim() == 1) { if (a.ndim() == 1) {
// Insert a singleton dim in the beginning // Insert a singleton dim in the beginning
a = reshape(a, {1, -1}, s); a = expand_dims(a, 0, s);
} }
if (b.ndim() == 1) { if (b.ndim() == 1) {
// Insert a singleton dim at the end // Insert a singleton dim at the end
b = reshape(b, {-1, 1}, s); b = expand_dims(b, 1, s);
} }
if (a.shape(-1) != b.shape(-2)) { if (a.shape(-1) != b.shape(-2)) {
@ -4212,20 +4227,20 @@ array gather_mm(
// Caculate array // Caculate array
auto out = array( auto out = array(
out_shape, std::move(out_shape),
out_type, out_type,
std::make_shared<GatherMM>(to_stream(s)), std::make_shared<GatherMM>(to_stream(s)),
{a, b, lhs_indices, rhs_indices}); {a, b, lhs_indices, rhs_indices});
// Remove the possibly inserted singleton dimensions // Remove the possibly inserted singleton dimensions
if (in_a_ndim == 1 || in_b_ndim == 1) { std::vector<int> axes;
out_shape.erase( if (in_a_ndim == 1) {
out_shape.end() - ((in_a_ndim == 1) ? 2 : 1), axes.push_back(out.ndim() - 2);
out_shape.end() - ((in_b_ndim == 1) ? 0 : 1));
out = reshape(out, out_shape, s);
} }
if (in_b_ndim == 1) {
return out; axes.push_back(out.ndim() - 1);
}
return axes.empty() ? out : squeeze(out, axes, s);
} }
array diagonal( array diagonal(

View File

@ -1602,6 +1602,55 @@ std::pair<std::vector<array>, std::vector<int>> Expm1::vmap(
return {{expm1(inputs[0], stream())}, axes}; return {{expm1(inputs[0], stream())}, axes};
} }
std::vector<array> ExpandDims::vjp(
const std::vector<array>&,
const std::vector<array>& cotangents,
const std::vector<int>&,
const std::vector<array>&) {
return {squeeze(cotangents[0], axes_, stream())};
}
std::vector<array> ExpandDims::jvp(
const std::vector<array>&,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {expand_dims(tangents[0], axes_, stream())};
}
std::pair<std::vector<array>, std::vector<int>> ExpandDims::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0];
auto expand_axes = axes_;
for (auto& s : expand_axes) {
if (s >= axes[0]) {
s++;
} else {
ax++;
}
}
return {{expand_dims(inputs[0], std::move(expand_axes), stream())}, {ax}};
}
bool ExpandDims::is_equivalent(const Primitive& other) const {
const ExpandDims& a_other = static_cast<const ExpandDims&>(other);
return (axes_ == a_other.axes_);
}
Shape ExpandDims::output_shape(
const array& input,
const std::vector<int>& axes) {
auto shape = input.shape();
for (auto ax : axes) {
shape.insert(shape.begin() + ax, 1);
}
return shape;
}
std::vector<Shape> ExpandDims::output_shapes(const std::vector<array>& inputs) {
return {ExpandDims::output_shape(inputs[0], axes_)};
}
bool FFT::is_equivalent(const Primitive& other) const { bool FFT::is_equivalent(const Primitive& other) const {
const FFT& r_other = static_cast<const FFT&>(other); const FFT& r_other = static_cast<const FFT&>(other);
return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ && return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ &&
@ -1846,7 +1895,7 @@ bool Gather::is_equivalent(const Primitive& other) const {
std::vector<Shape> Gather::output_shapes(const std::vector<array>& inputs) { std::vector<Shape> Gather::output_shapes(const std::vector<array>& inputs) {
Shape out_shape; Shape out_shape;
if (inputs.size() > 1) { if (inputs.size() > 1) {
out_shape = inputs[0].shape(); out_shape = inputs[1].shape();
} }
out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end()); out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end());
return {std::move(out_shape)}; return {std::move(out_shape)};
@ -3847,6 +3896,57 @@ std::pair<std::vector<array>, std::vector<int>> Subtract::vmap(
return {{subtract(a, b, stream())}, {to_ax}}; return {{subtract(a, b, stream())}, {to_ax}};
} }
std::vector<array> Squeeze::vjp(
const std::vector<array>&,
const std::vector<array>& cotangents,
const std::vector<int>&,
const std::vector<array>&) {
return {expand_dims(cotangents[0], axes_, stream())};
}
std::vector<array> Squeeze::jvp(
const std::vector<array>&,
const std::vector<array>& tangents,
const std::vector<int>&) {
return {squeeze(tangents[0], axes_, stream())};
}
std::pair<std::vector<array>, std::vector<int>> Squeeze::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0];
auto squeeze_axes = axes_;
for (auto& s : squeeze_axes) {
if (s >= axes[0]) {
s++;
} else {
ax--;
}
}
return {{squeeze(inputs[0], std::move(squeeze_axes), stream())}, {ax}};
}
bool Squeeze::is_equivalent(const Primitive& other) const {
const Squeeze& a_other = static_cast<const Squeeze&>(other);
return (axes_ == a_other.axes_);
}
Shape Squeeze::output_shape(const array& input, const std::vector<int>& axes) {
Shape shape;
for (int i = 0, j = 0; i < input.ndim(); ++i) {
if (j < axes.size() && i == axes[j]) {
j++;
} else {
shape.push_back(input.shape(i));
}
}
return shape;
}
std::vector<Shape> Squeeze::output_shapes(const std::vector<array>& inputs) {
return {Squeeze::output_shape(inputs[0], axes_)};
}
std::vector<array> Tan::vjp( std::vector<array> Tan::vjp(
const std::vector<array>& primals, const std::vector<array>& primals,
const std::vector<array>& cotangents, const std::vector<array>& cotangents,

View File

@ -983,6 +983,28 @@ class Expm1 : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
}; };
class ExpandDims : public UnaryPrimitive {
public:
explicit ExpandDims(Stream stream, std::vector<int> axes)
: UnaryPrimitive(stream), axes_(std::move(axes)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(ExpandDims)
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, const std::vector<int>& axes);
private:
void eval(const std::vector<array>& inputs, array& out);
std::vector<int> axes_;
};
class FFT : public UnaryPrimitive { class FFT : public UnaryPrimitive {
public: public:
explicit FFT( explicit FFT(
@ -1046,9 +1068,11 @@ class Gather : public UnaryPrimitive {
public: public:
explicit Gather( explicit Gather(
Stream stream, Stream stream,
const std::vector<int>& axes, std::vector<int> axes,
const std::vector<int>& slice_sizes) std::vector<int> slice_sizes)
: UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {} : UnaryPrimitive(stream),
axes_(std::move(axes)),
slice_sizes_(std::move(slice_sizes)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -2057,6 +2081,28 @@ class Subtract : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
}; };
class Squeeze : public UnaryPrimitive {
public:
explicit Squeeze(Stream stream, std::vector<int> axes)
: UnaryPrimitive(stream), axes_(std::move(axes)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Squeeze)
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(const array& input, const std::vector<int>& axes);
private:
void eval(const std::vector<array>& inputs, array& out);
std::vector<int> axes_;
};
class Tan : public UnaryPrimitive { class Tan : public UnaryPrimitive {
public: public:
explicit Tan(Stream stream) : UnaryPrimitive(stream) {} explicit Tan(Stream stream) : UnaryPrimitive(stream) {}

View File

@ -144,23 +144,23 @@ array mlx_gather_nd(
int slice_index = 0; int slice_index = 0;
for (int i = 0; i < gather_indices.size(); i++) { for (int i = 0; i < gather_indices.size(); i++) {
if (is_slice[i]) { if (is_slice[i]) {
std::vector<int> index_shape(max_dims + num_slices, 1); Shape index_shape(max_dims + num_slices, 1);
index_shape[max_dims + slice_index] = gather_indices[i].shape(0); index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
gather_indices[i] = reshape(gather_indices[i], index_shape); gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
slice_index++; slice_index++;
} else { } else {
std::vector<int> index_shape = gather_indices[i].shape(); auto index_shape = gather_indices[i].shape();
index_shape.insert(index_shape.end(), num_slices, 1); index_shape.insert(index_shape.end(), num_slices, 1);
gather_indices[i] = reshape(gather_indices[i], index_shape); gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
} }
} }
} else { } else {
// reshape them so that the int/array indices are last // reshape them so that the int/array indices are last
for (int i = 0; i < gather_indices.size(); i++) { for (int i = 0; i < gather_indices.size(); i++) {
if (i < num_slices) { if (i < num_slices) {
std::vector<int> index_shape(max_dims + num_slices, 1); Shape index_shape(max_dims + num_slices, 1);
index_shape[i] = gather_indices[i].shape(0); index_shape[i] = gather_indices[i].shape(0);
gather_indices[i] = reshape(gather_indices[i], index_shape); gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
} }
} }
} }
@ -172,19 +172,11 @@ array mlx_gather_nd(
std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1); std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
src = gather(src, gather_indices, axes, slice_sizes); src = gather(src, gather_indices, axes, slice_sizes);
// Squeeze the dims // Squeeze the array index dims
std::vector<int> out_shape; for (auto& ax : axes) {
out_shape.insert( ax += max_dims + num_slices;
out_shape.end(), }
src.shape().begin(), return squeeze(src, axes);
src.shape().begin() + max_dims + num_slices);
out_shape.insert(
out_shape.end(),
src.shape().begin() + max_dims + num_slices + indices.size(),
src.shape().end());
src = reshape(src, out_shape);
return src;
} }
auto mlx_expand_ellipsis( auto mlx_expand_ellipsis(

View File

@ -392,27 +392,6 @@ class TestCompile(mlx_tests.MLXTestCase):
out = fun(x, y=y, z=z) out = fun(x, y=y, z=z)
self.assertEqual(out.item(), 6) self.assertEqual(out.item(), 6)
def test_shapeless_compile(self):
y = 1
@partial(mx.compile, shapeless=True)
def fun(x):
return x + y
x = mx.array([1, 2])
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3])))
# The function is not recompiled, so the change
# to y should not be reflected in the output
y = 2
x = mx.array([1, 2, 3])
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4])))
# Type change recompiles
x = mx.array([1.0, 2.0, 3.0])
self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0])))
fun(x, y=y, z=z)
def test_shapeless_compile(self): def test_shapeless_compile(self):
y = 1 y = 1
@ -477,6 +456,12 @@ class TestCompile(mlx_tests.MLXTestCase):
mx.eval(cfun(x1)) mx.eval(cfun(x1))
self.assertTrue(mx.array_equal(fun(x2), cfun(x2))) self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
def fun(x):
return x * x.sum(-1, keepdims=False)
cfun = mx.compile(fun, shapeless=True)
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
def test_compile_with_constant(self): def test_compile_with_constant(self):
# Test float # Test float
@partial(mx.compile) @partial(mx.compile)
@ -809,6 +794,13 @@ class TestCompile(mlx_tests.MLXTestCase):
out = fun(*inputs) out = fun(*inputs)
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20))) self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
def test_shapeless_compile_matmul(self):
a = mx.array([0.0, 1.0, 2.0])
b = mx.array([0.0, 1.0, 2.0])
fun = mx.compile(lambda a, b: a @ b, shapeless=True)
self.assertTrue(mx.allclose(fun(a, b), a @ b))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1835,6 +1835,9 @@ TEST_CASE("test broadcast") {
} }
TEST_CASE("test gather") { TEST_CASE("test gather") {
// Empty input, non-empty indices/slice
CHECK_THROWS(gather(array({}), array({1}), 0, {1}));
// More indices than dimensions // More indices than dimensions
CHECK_THROWS(gather(array(0), array({1}), 0, {1})); CHECK_THROWS(gather(array(0), array({1}), 0, {1}));