ExpandDims primitive (#1687)

* add squeeze primitive

* simplify squeeze, use in gather

* fix

* fix

* fix

* fix

* fix no cpu

* use squeeze in matmul and friends

* expand dims primitive

* comment
This commit is contained in:
Awni Hannun
2024-12-10 16:39:07 -08:00
committed by GitHub
parent 310ad8d9db
commit f76a49e555
13 changed files with 373 additions and 184 deletions

View File

@@ -20,7 +20,7 @@ namespace mlx::core {
namespace {
std::tuple<Shape, std::vector<int>, Shape, bool> compute_reduce_shape(
std::tuple<Shape, std::vector<int>, bool> compute_reduce_shape(
const std::vector<int>& axes,
const Shape& shape) {
bool is_noop = true;
@@ -40,18 +40,16 @@ std::tuple<Shape, std::vector<int>, Shape, bool> compute_reduce_shape(
throw std::invalid_argument("Duplicate axes detected in reduction.");
}
Shape out_shape;
Shape squeezed_shape;
for (int i = 0; i < ndim; ++i) {
if (axes_set.count(i) == 0) {
out_shape.push_back(shape[i]);
squeezed_shape.push_back(shape[i]);
} else {
out_shape.push_back(1);
}
is_noop &= (out_shape.back() == shape[i]);
}
std::vector<int> sorted_axes(axes_set.begin(), axes_set.end());
return {out_shape, sorted_axes, squeezed_shape, is_noop};
return {out_shape, sorted_axes, is_noop};
}
Dtype at_least_float(const Dtype& d) {
@@ -460,54 +458,51 @@ array hadamard_transform(
{astype(a, dtype, s)});
}
array squeeze_impl(
const array& a,
std::vector<int> axes,
StreamOrDevice s /* = {} */) {
for (auto& ax : axes) {
auto new_ax = ax < 0 ? ax + a.ndim() : ax;
if (new_ax < 0 || new_ax >= a.ndim()) {
std::ostringstream msg;
msg << "[squeeze] Invalid axes " << ax << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
if (a.shape(new_ax) != 1) {
std::ostringstream msg;
msg << "[squeeze] Cannot squeeze axis " << ax << " with size "
<< a.shape(ax) << " which is not equal to 1.";
throw std::invalid_argument(msg.str());
}
ax = new_ax;
}
auto shape = Squeeze::output_shape(a, axes);
return array(
std::move(shape),
a.dtype(),
std::make_shared<Squeeze>(to_stream(s), std::move(axes)),
{a});
}
array squeeze(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
std::set<int> unique_axes;
for (auto ax : axes) {
ax = ax < 0 ? ax + a.ndim() : ax;
if (ax < 0 || ax >= a.ndim()) {
std::ostringstream msg;
msg << "[squeeze] Invalid axes " << ax << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
if (a.shape(ax) != 1) {
std::ostringstream msg;
msg << "[squeeze] Cannot squeeze axis " << ax << " with size "
<< a.shape(ax) << " which is not equal to 1.";
throw std::invalid_argument(msg.str());
}
unique_axes.insert(ax);
unique_axes.insert(ax < 0 ? ax + a.ndim() : ax);
}
if (unique_axes.size() != axes.size()) {
throw std::invalid_argument("[squeeze] Received duplicate axes.");
}
std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end());
Shape shape;
for (int i = 0, j = 0; i < a.ndim(); ++i) {
if (j < sorted_axes.size() && i == sorted_axes[j]) {
j++;
} else {
shape.push_back(a.shape(i));
}
}
return reshape(a, std::move(shape), s);
return squeeze_impl(a, std::move(sorted_axes), s);
}
array squeeze(const array& a, int axis, StreamOrDevice s /* = {} */) {
int ax = axis < 0 ? axis + a.ndim() : axis;
if (ax < 0 || ax >= a.ndim()) {
std::ostringstream msg;
msg << "[squeeze] Invalid axis " << axis << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
auto shape = a.shape();
shape.erase(shape.begin() + ax);
return reshape(a, std::move(shape), s);
return squeeze_impl(a, {axis}, s);
}
array squeeze(const array& a, StreamOrDevice s /* = {} */) {
@@ -517,21 +512,34 @@ array squeeze(const array& a, StreamOrDevice s /* = {} */) {
axes.push_back(i);
}
}
return squeeze(a, axes, s);
return squeeze_impl(a, std::move(axes), s);
}
array expand_dims_impl(
const array& a,
std::vector<int> axes,
StreamOrDevice s /* = {} */) {
auto out_ndim = a.ndim() + axes.size();
for (auto& ax : axes) {
auto new_ax = ax < 0 ? ax + out_ndim : ax;
if (new_ax < 0 || new_ax >= out_ndim) {
std::ostringstream msg;
msg << "[expand_dims] Invalid axis " << ax << " for output array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
ax = new_ax;
}
auto shape = ExpandDims::output_shape(a, axes);
return array(
std::move(shape),
a.dtype(),
std::make_shared<ExpandDims>(to_stream(s), std::move(axes)),
{a});
}
array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) {
int out_dim = a.ndim() + 1;
int ax = axis < 0 ? axis + out_dim : axis;
if (ax < 0 || ax >= out_dim) {
std::ostringstream msg;
msg << "[expand_dims] Invalid axis " << axis << " for output array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
auto shape = a.shape();
shape.insert(shape.begin() + ax, 1);
return reshape(a, std::move(shape), s);
return expand_dims_impl(a, {axis}, s);
}
array expand_dims(
@@ -544,31 +552,17 @@ array expand_dims(
throw std::invalid_argument("[expand_dims] Received duplicate axes.");
}
}
int out_ndim = axes.size() + a.ndim();
std::vector<int> canonical_axes = axes;
for (auto& ax : canonical_axes) {
ax = ax < 0 ? ax + out_ndim : ax;
if (ax < 0 || ax >= out_ndim) {
std::ostringstream msg;
msg << "[expand_dims] Invalid axis " << ax << " for output array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
}
// Check for repeats again
std::set<int> unique_axes(canonical_axes.begin(), canonical_axes.end());
auto out_ndim = a.ndim() + axes.size();
std::set<int> unique_axes;
for (auto ax : axes) {
unique_axes.insert(ax < 0 ? ax + out_ndim : ax);
}
if (unique_axes.size() != axes.size()) {
throw std::invalid_argument("[expand_dims] Received duplicate axes.");
}
std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end());
auto out_shape = a.shape();
for (int i = 0; i < sorted_axes.size(); ++i) {
out_shape.insert(out_shape.begin() + sorted_axes[i], 1);
}
return reshape(a, std::move(out_shape), s);
return expand_dims_impl(a, std::move(sorted_axes), s);
}
// Slice helper
@@ -1519,7 +1513,7 @@ array all(
const std::vector<int>& axes,
bool keepdims /* = false */,
StreamOrDevice s /* = {}*/) {
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? astype(a, bool_, s)
@@ -1529,7 +1523,7 @@ array all(
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
{a});
if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s);
out = squeeze(out, sorted_axes, s);
}
return out;
}
@@ -1553,7 +1547,7 @@ array any(
const std::vector<int>& axes,
bool keepdims /* = false */,
StreamOrDevice s /* = {}*/) {
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? astype(a, bool_, s)
@@ -1563,7 +1557,7 @@ array any(
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
{a});
if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s);
out = squeeze(out, sorted_axes, s);
}
return out;
}
@@ -1590,7 +1584,7 @@ array sum(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape());
Dtype out_type = a.dtype();
if (issubdtype(a.dtype(), signedinteger)) {
@@ -1608,7 +1602,7 @@ array sum(
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
{a});
if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s);
out = squeeze(out, sorted_axes, s);
}
return out;
}
@@ -1742,7 +1736,7 @@ array prod(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape());
Dtype out_type = a.dtype();
if (issubdtype(a.dtype(), signedinteger)) {
@@ -1760,7 +1754,7 @@ array prod(
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
{a});
if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s);
out = squeeze(out, sorted_axes, s);
}
return out;
}
@@ -1787,7 +1781,7 @@ array max(
if (a.size() == 0) {
throw std::invalid_argument("[max] Cannot max reduce zero size array.");
}
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? a
@@ -1797,7 +1791,7 @@ array max(
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
{a});
if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s);
out = squeeze(out, sorted_axes, s);
}
return out;
}
@@ -1827,7 +1821,7 @@ array min(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape(axes, a.shape());
auto out = (is_noop)
? a
@@ -1837,7 +1831,7 @@ array min(
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
{a});
if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s);
out = squeeze(out, sorted_axes, s);
}
return out;
}
@@ -1870,7 +1864,7 @@ array argmin(
throw std::invalid_argument(
"[argmin] Cannot argmin reduce zero size array.");
}
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape({axis}, a.shape());
auto out = (is_noop)
? zeros(out_shape, uint32, s)
@@ -1881,7 +1875,7 @@ array argmin(
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
{a});
if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s);
out = squeeze(out, sorted_axes[0], s);
}
return out;
}
@@ -1906,7 +1900,7 @@ array argmax(
throw std::invalid_argument(
"[argmax] Cannot argmax reduce zero size array.");
}
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
auto [out_shape, sorted_axes, is_noop] =
compute_reduce_shape({axis}, a.shape());
auto out = (is_noop)
? zeros(out_shape, uint32, s)
@@ -1917,7 +1911,7 @@ array argmax(
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
{a});
if (!keepdims) {
out = reshape(out, std::move(squeezed_shape), s);
out = squeeze(out, sorted_axes[0], s);
}
return out;
}
@@ -2544,11 +2538,11 @@ array matmul(
}
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = reshape(a, {1, -1}, s);
a = expand_dims(a, 0, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = reshape(b, {-1, 1}, s);
b = expand_dims(b, 1, s);
}
if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg;
@@ -2608,17 +2602,21 @@ array matmul(
auto out_shape = a.shape();
out_shape.back() = b.shape(-1);
auto p = std::make_shared<Matmul>(to_stream(s));
auto out = array(
std::move(out_shape),
out_type,
std::make_shared<Matmul>(to_stream(s)),
{a, b});
// Remove the possibly inserted singleton dimensions
if (in_a.ndim() == 1 || in_b.ndim() == 1) {
auto out = array(out_shape, out_type, std::move(p), {a, b});
out_shape.erase(
out_shape.end() - ((in_a.ndim() == 1) ? 2 : 1),
out_shape.end() - ((in_b.ndim() == 1) ? 0 : 1));
return reshape(out, std::move(out_shape), s);
std::vector<int> axes;
if (in_a.ndim() == 1) {
axes.push_back(out.ndim() - 2);
}
return array(std::move(out_shape), out_type, std::move(p), {a, b});
if (in_b.ndim() == 1) {
axes.push_back(out.ndim() - 1);
}
return axes.empty() ? out : squeeze(out, axes, s);
}
array gather(
@@ -2658,15 +2656,6 @@ array gather(
<< " for array with " << a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
for (int i = 0; i < a.ndim(); ++i) {
if (slice_sizes[i] < 0 || slice_sizes[i] > a.shape(i)) {
std::ostringstream msg;
msg << "[gather] Slice sizes must be in [0, a.shape(i)]. Got "
<< slice_sizes << " for array with shape " << a.shape() << ".";
throw std::invalid_argument(msg.str());
}
}
// Promote indices to the same type
auto dtype = result_type(indices);
if (issubdtype(dtype, inexact)) {
@@ -2680,6 +2669,29 @@ array gather(
idx = astype(idx, dtype, s);
}
if (a.size() == 0) {
// Empty input, either the total slice size is 0 or the indices are empty
auto total_slice = std::accumulate(
slice_sizes.begin(), slice_sizes.end(), 1, std::multiplies<int64_t>{});
auto idx_size = !inputs.empty() ? inputs[0].size() : 1;
if (idx_size != 0 && total_slice != 0) {
std::ostringstream msg;
msg << "[gather] If the input is empty, either the indices must be"
<< " empty or the total slice size must be 0.";
throw std::invalid_argument(msg.str());
}
} else {
// Non-empty input, check slice sizes are valid
for (int i = 0; i < a.ndim(); ++i) {
if (slice_sizes[i] < 0 || slice_sizes[i] > a.shape(i)) {
std::ostringstream msg;
msg << "[gather] Slice sizes must be in [0, a.shape(i)]. Got "
<< slice_sizes << " for array with shape " << a.shape() << ".";
throw std::invalid_argument(msg.str());
}
}
}
Shape out_shape;
if (!inputs.empty()) {
out_shape = inputs[0].shape();
@@ -2688,9 +2700,10 @@ array gather(
inputs.insert(inputs.begin(), a);
return array(
out_shape,
std::move(out_shape),
a.dtype(),
std::make_shared<Gather>(to_stream(s), axes, slice_sizes),
std::make_shared<Gather>(
to_stream(s), std::move(axes), std::move(slice_sizes)),
inputs);
}
@@ -2719,7 +2732,7 @@ array take(
// Make slice sizes to pass to gather
Shape slice_sizes = a.shape();
slice_sizes[axis] = indices.size() > 0 ? 1 : 0;
slice_sizes[axis] = 1;
auto out = gather(a, indices, axis, slice_sizes, s);
@@ -2736,9 +2749,7 @@ array take(
}
// Squeeze the axis we take over
auto out_shape = out.shape();
out_shape.erase(out_shape.begin() + indices.ndim() + axis);
return reshape(out, std::move(out_shape), s);
return squeeze(out, indices.ndim() + axis, s);
}
array take(const array& a, const array& indices, StreamOrDevice s /* = {} */) {
@@ -2811,12 +2822,14 @@ array take_along_axis(
}
std::vector<int> dims(a.ndim());
std::iota(dims.begin(), dims.end(), 0);
Shape slice_sizes(a.ndim(), a.size() > 0);
Shape slice_sizes(a.ndim(), 1);
auto out = gather(a, nd_indices, dims, slice_sizes, s);
// Squeeze out the slice shape
Shape out_shape(out.shape().begin(), out.shape().begin() + a.ndim());
return reshape(out, std::move(out_shape), s);
for (auto& d : dims) {
d += a.ndim();
}
return squeeze(out, dims, s);
}
array put_along_axis(
@@ -3935,17 +3948,20 @@ array addmm(
}
auto out = array(
out_shape,
std::move(out_shape),
out_type,
std::make_shared<AddMM>(to_stream(s), alpha, beta),
{a, b, c});
// Remove the possibly inserted singleton dimensions
if (in_a_ndim == 1 || in_b_ndim == 1) {
out = reshape(out, out_shape_adjusted, s);
std::vector<int> axes;
if (in_a_ndim == 1) {
axes.push_back(out.ndim() - 2);
}
return out;
if (in_b_ndim == 1) {
axes.push_back(out.ndim() - 1);
}
return axes.empty() ? out : squeeze(out, axes, s);
}
/** Compute matrix product with tile-level masking */
@@ -3986,11 +4002,11 @@ array block_masked_mm(
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = reshape(a, {1, -1}, s);
a = expand_dims(a, 0, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = reshape(b, {-1, 1}, s);
b = expand_dims(b, 1, s);
}
if (a.shape(-1) != b.shape(-2)) {
@@ -4110,20 +4126,19 @@ array block_masked_mm(
// Caculate array
auto out = array(
out_shape,
std::move(out_shape),
out_type,
std::make_shared<BlockMaskedMM>(to_stream(s), block_size),
std::move(inputs));
// Remove the possibly inserted singleton dimensions
if (in_a_ndim == 1 || in_b_ndim == 1) {
out_shape.erase(
out_shape.end() - ((in_a_ndim == 1) ? 2 : 1),
out_shape.end() - ((in_b_ndim == 1) ? 0 : 1));
out = reshape(out, out_shape, s);
std::vector<int> axes;
if (in_a_ndim == 1) {
axes.push_back(out.ndim() - 2);
}
return out;
if (in_b_ndim == 1) {
axes.push_back(out.ndim() - 1);
}
return axes.empty() ? out : squeeze(out, axes, s);
}
/** Compute matrix product with matrix-level gather */
@@ -4150,11 +4165,11 @@ array gather_mm(
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = reshape(a, {1, -1}, s);
a = expand_dims(a, 0, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = reshape(b, {-1, 1}, s);
b = expand_dims(b, 1, s);
}
if (a.shape(-1) != b.shape(-2)) {
@@ -4212,20 +4227,20 @@ array gather_mm(
// Caculate array
auto out = array(
out_shape,
std::move(out_shape),
out_type,
std::make_shared<GatherMM>(to_stream(s)),
{a, b, lhs_indices, rhs_indices});
// Remove the possibly inserted singleton dimensions
if (in_a_ndim == 1 || in_b_ndim == 1) {
out_shape.erase(
out_shape.end() - ((in_a_ndim == 1) ? 2 : 1),
out_shape.end() - ((in_b_ndim == 1) ? 0 : 1));
out = reshape(out, out_shape, s);
std::vector<int> axes;
if (in_a_ndim == 1) {
axes.push_back(out.ndim() - 2);
}
return out;
if (in_b_ndim == 1) {
axes.push_back(out.ndim() - 1);
}
return axes.empty() ? out : squeeze(out, axes, s);
}
array diagonal(