diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 1f80224ad..c4f3658a7 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -43,6 +43,7 @@ DEFAULT(NumberOfElements) DEFAULT(Equal) DEFAULT(Erf) DEFAULT(ErfInv) +DEFAULT(ExpandDims) DEFAULT(FFT) DEFAULT(Floor) DEFAULT(Gather) @@ -76,6 +77,7 @@ DEFAULT(Slice) DEFAULT(SliceUpdate) DEFAULT_MULTI(Split) DEFAULT(Sort) +DEFAULT(Squeeze) DEFAULT(StopGradient) DEFAULT_MULTI(SVD) DEFAULT(Transpose) diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp index 0a677a01b..b0957b54e 100644 --- a/mlx/backend/common/common.cpp +++ b/mlx/backend/common/common.cpp @@ -85,6 +85,16 @@ void Depends::eval( } } +void ExpandDims::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + auto strides = in.strides(); + for (auto ax : axes_) { + strides.insert(strides.begin() + ax, 1); + } + move_or_copy(in, out, strides, in.flags(), in.data_size()); +} + void NumberOfElements::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); out.set_data(allocator::malloc_or_wait(out.nbytes())); @@ -248,6 +258,20 @@ void Split::eval( } } +void Squeeze::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + Strides strides; + for (int i = 0, j = 0; i < in.ndim(); ++i) { + if (j < axes_.size() && i == axes_[j]) { + j++; + } else { + strides.push_back(in.strides(i)); + } + } + move_or_copy(in, out, strides, in.flags(), in.data_size()); +} + void StopGradient::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); move_or_copy(inputs[0], out); diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 3313ac0e1..a79aca81b 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -57,6 +57,7 @@ DEFAULT(Equal) DEFAULT(Erf) DEFAULT(ErfInv) DEFAULT(Exp) +DEFAULT(ExpandDims) DEFAULT(Expm1) DEFAULT(FFT) DEFAULT(Floor) @@ -101,6 +102,7 @@ DEFAULT(Softmax) DEFAULT(Sort) DEFAULT_MULTI(Split) DEFAULT(Square) +DEFAULT(Squeeze) DEFAULT(Sqrt) DEFAULT(StopGradient) DEFAULT(Subtract) diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 5c86d2b84..729001604 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -211,6 +211,10 @@ void Full::eval_gpu(const std::vector& inputs, array& out) { copy_gpu(in, out, ctype); } +void ExpandDims::eval_gpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} + void Load::eval_gpu(const std::vector& inputs, array& out) { out.set_data(allocator::malloc_or_wait(out.nbytes())); auto read_task = [out = out, @@ -381,6 +385,10 @@ void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { /* const Stream& s = */ stream()); } +void Squeeze::eval_gpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} + void StopGradient::eval_gpu(const std::vector& inputs, array& out) { eval(inputs, out); } diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index b2a83b997..dab493345 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -55,6 +55,7 @@ NO_CPU(Equal) NO_CPU(Erf) NO_CPU(ErfInv) NO_CPU(Exp) +NO_CPU(ExpandDims) NO_CPU(Expm1) NO_CPU(FFT) NO_CPU(Floor) @@ -104,6 +105,7 @@ NO_CPU(Softmax) NO_CPU(Sort) NO_CPU_MULTI(Split) NO_CPU(Square) +NO_CPU(Squeeze) NO_CPU(Sqrt) NO_CPU(StopGradient) NO_CPU(Subtract) diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 98c89037e..cfae4ff38 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -55,6 +55,7 @@ NO_GPU(Equal) NO_GPU(Erf) NO_GPU(ErfInv) NO_GPU(Exp) +NO_GPU(ExpandDims) NO_GPU(Expm1) NO_GPU(FFT) NO_GPU(Floor) @@ -104,6 +105,7 @@ NO_GPU(Softmax) NO_GPU(Sort) NO_GPU_MULTI(Split) NO_GPU(Square) +NO_GPU(Squeeze) NO_GPU(Sqrt) NO_GPU(StopGradient) NO_GPU(Subtract) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 9bfdb6ee9..95ff59865 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -81,6 +81,7 @@ bool allows_shapeless(const Primitive& p) { typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) || typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) || typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) || + typeid(p) == typeid(Squeeze) || typeid(p) == typeid(ExpandDims) || typeid(p) == typeid(fast::AffineQuantize) || typeid(p) == typeid(fast::LayerNorm) || typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) || diff --git a/mlx/ops.cpp b/mlx/ops.cpp index af772ce61..c43aba371 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -20,7 +20,7 @@ namespace mlx::core { namespace { -std::tuple, Shape, bool> compute_reduce_shape( +std::tuple, bool> compute_reduce_shape( const std::vector& axes, const Shape& shape) { bool is_noop = true; @@ -40,18 +40,16 @@ std::tuple, Shape, bool> compute_reduce_shape( throw std::invalid_argument("Duplicate axes detected in reduction."); } Shape out_shape; - Shape squeezed_shape; for (int i = 0; i < ndim; ++i) { if (axes_set.count(i) == 0) { out_shape.push_back(shape[i]); - squeezed_shape.push_back(shape[i]); } else { out_shape.push_back(1); } is_noop &= (out_shape.back() == shape[i]); } std::vector sorted_axes(axes_set.begin(), axes_set.end()); - return {out_shape, sorted_axes, squeezed_shape, is_noop}; + return {out_shape, sorted_axes, is_noop}; } Dtype at_least_float(const Dtype& d) { @@ -460,54 +458,51 @@ array hadamard_transform( {astype(a, dtype, s)}); } +array squeeze_impl( + const array& a, + std::vector axes, + StreamOrDevice s /* = {} */) { + for (auto& ax : axes) { + auto new_ax = ax < 0 ? ax + a.ndim() : ax; + if (new_ax < 0 || new_ax >= a.ndim()) { + std::ostringstream msg; + msg << "[squeeze] Invalid axes " << ax << " for array with " << a.ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + if (a.shape(new_ax) != 1) { + std::ostringstream msg; + msg << "[squeeze] Cannot squeeze axis " << ax << " with size " + << a.shape(ax) << " which is not equal to 1."; + throw std::invalid_argument(msg.str()); + } + ax = new_ax; + } + auto shape = Squeeze::output_shape(a, axes); + return array( + std::move(shape), + a.dtype(), + std::make_shared(to_stream(s), std::move(axes)), + {a}); +} + array squeeze( const array& a, const std::vector& axes, StreamOrDevice s /* = {} */) { std::set unique_axes; for (auto ax : axes) { - ax = ax < 0 ? ax + a.ndim() : ax; - if (ax < 0 || ax >= a.ndim()) { - std::ostringstream msg; - msg << "[squeeze] Invalid axes " << ax << " for array with " << a.ndim() - << " dimensions."; - throw std::invalid_argument(msg.str()); - } - if (a.shape(ax) != 1) { - std::ostringstream msg; - msg << "[squeeze] Cannot squeeze axis " << ax << " with size " - << a.shape(ax) << " which is not equal to 1."; - throw std::invalid_argument(msg.str()); - } - unique_axes.insert(ax); + unique_axes.insert(ax < 0 ? ax + a.ndim() : ax); } - if (unique_axes.size() != axes.size()) { throw std::invalid_argument("[squeeze] Received duplicate axes."); } std::vector sorted_axes(unique_axes.begin(), unique_axes.end()); - Shape shape; - for (int i = 0, j = 0; i < a.ndim(); ++i) { - if (j < sorted_axes.size() && i == sorted_axes[j]) { - j++; - } else { - shape.push_back(a.shape(i)); - } - } - return reshape(a, std::move(shape), s); + return squeeze_impl(a, std::move(sorted_axes), s); } array squeeze(const array& a, int axis, StreamOrDevice s /* = {} */) { - int ax = axis < 0 ? axis + a.ndim() : axis; - if (ax < 0 || ax >= a.ndim()) { - std::ostringstream msg; - msg << "[squeeze] Invalid axis " << axis << " for array with " << a.ndim() - << " dimensions."; - throw std::invalid_argument(msg.str()); - } - auto shape = a.shape(); - shape.erase(shape.begin() + ax); - return reshape(a, std::move(shape), s); + return squeeze_impl(a, {axis}, s); } array squeeze(const array& a, StreamOrDevice s /* = {} */) { @@ -517,21 +512,34 @@ array squeeze(const array& a, StreamOrDevice s /* = {} */) { axes.push_back(i); } } - return squeeze(a, axes, s); + return squeeze_impl(a, std::move(axes), s); +} + +array expand_dims_impl( + const array& a, + std::vector axes, + StreamOrDevice s /* = {} */) { + auto out_ndim = a.ndim() + axes.size(); + for (auto& ax : axes) { + auto new_ax = ax < 0 ? ax + out_ndim : ax; + if (new_ax < 0 || new_ax >= out_ndim) { + std::ostringstream msg; + msg << "[expand_dims] Invalid axis " << ax << " for output array with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + ax = new_ax; + } + auto shape = ExpandDims::output_shape(a, axes); + return array( + std::move(shape), + a.dtype(), + std::make_shared(to_stream(s), std::move(axes)), + {a}); } array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) { - int out_dim = a.ndim() + 1; - int ax = axis < 0 ? axis + out_dim : axis; - if (ax < 0 || ax >= out_dim) { - std::ostringstream msg; - msg << "[expand_dims] Invalid axis " << axis << " for output array with " - << a.ndim() << " dimensions."; - throw std::invalid_argument(msg.str()); - } - auto shape = a.shape(); - shape.insert(shape.begin() + ax, 1); - return reshape(a, std::move(shape), s); + return expand_dims_impl(a, {axis}, s); } array expand_dims( @@ -544,31 +552,17 @@ array expand_dims( throw std::invalid_argument("[expand_dims] Received duplicate axes."); } } - - int out_ndim = axes.size() + a.ndim(); - std::vector canonical_axes = axes; - for (auto& ax : canonical_axes) { - ax = ax < 0 ? ax + out_ndim : ax; - if (ax < 0 || ax >= out_ndim) { - std::ostringstream msg; - msg << "[expand_dims] Invalid axis " << ax << " for output array with " - << a.ndim() << " dimensions."; - throw std::invalid_argument(msg.str()); - } - } - // Check for repeats again - std::set unique_axes(canonical_axes.begin(), canonical_axes.end()); + auto out_ndim = a.ndim() + axes.size(); + std::set unique_axes; + for (auto ax : axes) { + unique_axes.insert(ax < 0 ? ax + out_ndim : ax); + } if (unique_axes.size() != axes.size()) { throw std::invalid_argument("[expand_dims] Received duplicate axes."); } - std::vector sorted_axes(unique_axes.begin(), unique_axes.end()); - auto out_shape = a.shape(); - for (int i = 0; i < sorted_axes.size(); ++i) { - out_shape.insert(out_shape.begin() + sorted_axes[i], 1); - } - return reshape(a, std::move(out_shape), s); + return expand_dims_impl(a, std::move(sorted_axes), s); } // Slice helper @@ -1519,7 +1513,7 @@ array all( const std::vector& axes, bool keepdims /* = false */, StreamOrDevice s /* = {}*/) { - auto [out_shape, sorted_axes, squeezed_shape, is_noop] = + auto [out_shape, sorted_axes, is_noop] = compute_reduce_shape(axes, a.shape()); auto out = (is_noop) ? astype(a, bool_, s) @@ -1529,7 +1523,7 @@ array all( std::make_shared(to_stream(s), Reduce::And, sorted_axes), {a}); if (!keepdims) { - out = reshape(out, std::move(squeezed_shape), s); + out = squeeze(out, sorted_axes, s); } return out; } @@ -1553,7 +1547,7 @@ array any( const std::vector& axes, bool keepdims /* = false */, StreamOrDevice s /* = {}*/) { - auto [out_shape, sorted_axes, squeezed_shape, is_noop] = + auto [out_shape, sorted_axes, is_noop] = compute_reduce_shape(axes, a.shape()); auto out = (is_noop) ? astype(a, bool_, s) @@ -1563,7 +1557,7 @@ array any( std::make_shared(to_stream(s), Reduce::Or, sorted_axes), {a}); if (!keepdims) { - out = reshape(out, std::move(squeezed_shape), s); + out = squeeze(out, sorted_axes, s); } return out; } @@ -1590,7 +1584,7 @@ array sum( if (axes.empty()) { return a; } - auto [out_shape, sorted_axes, squeezed_shape, is_noop] = + auto [out_shape, sorted_axes, is_noop] = compute_reduce_shape(axes, a.shape()); Dtype out_type = a.dtype(); if (issubdtype(a.dtype(), signedinteger)) { @@ -1608,7 +1602,7 @@ array sum( std::make_shared(to_stream(s), Reduce::Sum, sorted_axes), {a}); if (!keepdims) { - out = reshape(out, std::move(squeezed_shape), s); + out = squeeze(out, sorted_axes, s); } return out; } @@ -1742,7 +1736,7 @@ array prod( if (axes.empty()) { return a; } - auto [out_shape, sorted_axes, squeezed_shape, is_noop] = + auto [out_shape, sorted_axes, is_noop] = compute_reduce_shape(axes, a.shape()); Dtype out_type = a.dtype(); if (issubdtype(a.dtype(), signedinteger)) { @@ -1760,7 +1754,7 @@ array prod( std::make_shared(to_stream(s), Reduce::Prod, sorted_axes), {a}); if (!keepdims) { - out = reshape(out, std::move(squeezed_shape), s); + out = squeeze(out, sorted_axes, s); } return out; } @@ -1787,7 +1781,7 @@ array max( if (a.size() == 0) { throw std::invalid_argument("[max] Cannot max reduce zero size array."); } - auto [out_shape, sorted_axes, squeezed_shape, is_noop] = + auto [out_shape, sorted_axes, is_noop] = compute_reduce_shape(axes, a.shape()); auto out = (is_noop) ? a @@ -1797,7 +1791,7 @@ array max( std::make_shared(to_stream(s), Reduce::Max, sorted_axes), {a}); if (!keepdims) { - out = reshape(out, std::move(squeezed_shape), s); + out = squeeze(out, sorted_axes, s); } return out; } @@ -1827,7 +1821,7 @@ array min( if (axes.empty()) { return a; } - auto [out_shape, sorted_axes, squeezed_shape, is_noop] = + auto [out_shape, sorted_axes, is_noop] = compute_reduce_shape(axes, a.shape()); auto out = (is_noop) ? a @@ -1837,7 +1831,7 @@ array min( std::make_shared(to_stream(s), Reduce::Min, sorted_axes), {a}); if (!keepdims) { - out = reshape(out, std::move(squeezed_shape), s); + out = squeeze(out, sorted_axes, s); } return out; } @@ -1870,7 +1864,7 @@ array argmin( throw std::invalid_argument( "[argmin] Cannot argmin reduce zero size array."); } - auto [out_shape, sorted_axes, squeezed_shape, is_noop] = + auto [out_shape, sorted_axes, is_noop] = compute_reduce_shape({axis}, a.shape()); auto out = (is_noop) ? zeros(out_shape, uint32, s) @@ -1881,7 +1875,7 @@ array argmin( to_stream(s), ArgReduce::ArgMin, sorted_axes[0]), {a}); if (!keepdims) { - out = reshape(out, std::move(squeezed_shape), s); + out = squeeze(out, sorted_axes[0], s); } return out; } @@ -1906,7 +1900,7 @@ array argmax( throw std::invalid_argument( "[argmax] Cannot argmax reduce zero size array."); } - auto [out_shape, sorted_axes, squeezed_shape, is_noop] = + auto [out_shape, sorted_axes, is_noop] = compute_reduce_shape({axis}, a.shape()); auto out = (is_noop) ? zeros(out_shape, uint32, s) @@ -1917,7 +1911,7 @@ array argmax( to_stream(s), ArgReduce::ArgMax, sorted_axes[0]), {a}); if (!keepdims) { - out = reshape(out, std::move(squeezed_shape), s); + out = squeeze(out, sorted_axes[0], s); } return out; } @@ -2544,11 +2538,11 @@ array matmul( } if (a.ndim() == 1) { // Insert a singleton dim in the beginning - a = reshape(a, {1, -1}, s); + a = expand_dims(a, 0, s); } if (b.ndim() == 1) { // Insert a singleton dim at the end - b = reshape(b, {-1, 1}, s); + b = expand_dims(b, 1, s); } if (a.shape(-1) != b.shape(-2)) { std::ostringstream msg; @@ -2608,17 +2602,21 @@ array matmul( auto out_shape = a.shape(); out_shape.back() = b.shape(-1); - auto p = std::make_shared(to_stream(s)); + auto out = array( + std::move(out_shape), + out_type, + std::make_shared(to_stream(s)), + {a, b}); // Remove the possibly inserted singleton dimensions - if (in_a.ndim() == 1 || in_b.ndim() == 1) { - auto out = array(out_shape, out_type, std::move(p), {a, b}); - out_shape.erase( - out_shape.end() - ((in_a.ndim() == 1) ? 2 : 1), - out_shape.end() - ((in_b.ndim() == 1) ? 0 : 1)); - return reshape(out, std::move(out_shape), s); + std::vector axes; + if (in_a.ndim() == 1) { + axes.push_back(out.ndim() - 2); } - return array(std::move(out_shape), out_type, std::move(p), {a, b}); + if (in_b.ndim() == 1) { + axes.push_back(out.ndim() - 1); + } + return axes.empty() ? out : squeeze(out, axes, s); } array gather( @@ -2658,15 +2656,6 @@ array gather( << " for array with " << a.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } - for (int i = 0; i < a.ndim(); ++i) { - if (slice_sizes[i] < 0 || slice_sizes[i] > a.shape(i)) { - std::ostringstream msg; - msg << "[gather] Slice sizes must be in [0, a.shape(i)]. Got " - << slice_sizes << " for array with shape " << a.shape() << "."; - throw std::invalid_argument(msg.str()); - } - } - // Promote indices to the same type auto dtype = result_type(indices); if (issubdtype(dtype, inexact)) { @@ -2680,6 +2669,29 @@ array gather( idx = astype(idx, dtype, s); } + if (a.size() == 0) { + // Empty input, either the total slice size is 0 or the indices are empty + auto total_slice = std::accumulate( + slice_sizes.begin(), slice_sizes.end(), 1, std::multiplies{}); + auto idx_size = !inputs.empty() ? inputs[0].size() : 1; + if (idx_size != 0 && total_slice != 0) { + std::ostringstream msg; + msg << "[gather] If the input is empty, either the indices must be" + << " empty or the total slice size must be 0."; + throw std::invalid_argument(msg.str()); + } + } else { + // Non-empty input, check slice sizes are valid + for (int i = 0; i < a.ndim(); ++i) { + if (slice_sizes[i] < 0 || slice_sizes[i] > a.shape(i)) { + std::ostringstream msg; + msg << "[gather] Slice sizes must be in [0, a.shape(i)]. Got " + << slice_sizes << " for array with shape " << a.shape() << "."; + throw std::invalid_argument(msg.str()); + } + } + } + Shape out_shape; if (!inputs.empty()) { out_shape = inputs[0].shape(); @@ -2688,9 +2700,10 @@ array gather( inputs.insert(inputs.begin(), a); return array( - out_shape, + std::move(out_shape), a.dtype(), - std::make_shared(to_stream(s), axes, slice_sizes), + std::make_shared( + to_stream(s), std::move(axes), std::move(slice_sizes)), inputs); } @@ -2719,7 +2732,7 @@ array take( // Make slice sizes to pass to gather Shape slice_sizes = a.shape(); - slice_sizes[axis] = indices.size() > 0 ? 1 : 0; + slice_sizes[axis] = 1; auto out = gather(a, indices, axis, slice_sizes, s); @@ -2736,9 +2749,7 @@ array take( } // Squeeze the axis we take over - auto out_shape = out.shape(); - out_shape.erase(out_shape.begin() + indices.ndim() + axis); - return reshape(out, std::move(out_shape), s); + return squeeze(out, indices.ndim() + axis, s); } array take(const array& a, const array& indices, StreamOrDevice s /* = {} */) { @@ -2811,12 +2822,14 @@ array take_along_axis( } std::vector dims(a.ndim()); std::iota(dims.begin(), dims.end(), 0); - Shape slice_sizes(a.ndim(), a.size() > 0); + Shape slice_sizes(a.ndim(), 1); auto out = gather(a, nd_indices, dims, slice_sizes, s); // Squeeze out the slice shape - Shape out_shape(out.shape().begin(), out.shape().begin() + a.ndim()); - return reshape(out, std::move(out_shape), s); + for (auto& d : dims) { + d += a.ndim(); + } + return squeeze(out, dims, s); } array put_along_axis( @@ -3935,17 +3948,20 @@ array addmm( } auto out = array( - out_shape, + std::move(out_shape), out_type, std::make_shared(to_stream(s), alpha, beta), {a, b, c}); // Remove the possibly inserted singleton dimensions - if (in_a_ndim == 1 || in_b_ndim == 1) { - out = reshape(out, out_shape_adjusted, s); + std::vector axes; + if (in_a_ndim == 1) { + axes.push_back(out.ndim() - 2); } - - return out; + if (in_b_ndim == 1) { + axes.push_back(out.ndim() - 1); + } + return axes.empty() ? out : squeeze(out, axes, s); } /** Compute matrix product with tile-level masking */ @@ -3986,11 +4002,11 @@ array block_masked_mm( if (a.ndim() == 1) { // Insert a singleton dim in the beginning - a = reshape(a, {1, -1}, s); + a = expand_dims(a, 0, s); } if (b.ndim() == 1) { // Insert a singleton dim at the end - b = reshape(b, {-1, 1}, s); + b = expand_dims(b, 1, s); } if (a.shape(-1) != b.shape(-2)) { @@ -4110,20 +4126,19 @@ array block_masked_mm( // Caculate array auto out = array( - out_shape, + std::move(out_shape), out_type, std::make_shared(to_stream(s), block_size), std::move(inputs)); - // Remove the possibly inserted singleton dimensions - if (in_a_ndim == 1 || in_b_ndim == 1) { - out_shape.erase( - out_shape.end() - ((in_a_ndim == 1) ? 2 : 1), - out_shape.end() - ((in_b_ndim == 1) ? 0 : 1)); - out = reshape(out, out_shape, s); + std::vector axes; + if (in_a_ndim == 1) { + axes.push_back(out.ndim() - 2); } - - return out; + if (in_b_ndim == 1) { + axes.push_back(out.ndim() - 1); + } + return axes.empty() ? out : squeeze(out, axes, s); } /** Compute matrix product with matrix-level gather */ @@ -4150,11 +4165,11 @@ array gather_mm( if (a.ndim() == 1) { // Insert a singleton dim in the beginning - a = reshape(a, {1, -1}, s); + a = expand_dims(a, 0, s); } if (b.ndim() == 1) { // Insert a singleton dim at the end - b = reshape(b, {-1, 1}, s); + b = expand_dims(b, 1, s); } if (a.shape(-1) != b.shape(-2)) { @@ -4212,20 +4227,20 @@ array gather_mm( // Caculate array auto out = array( - out_shape, + std::move(out_shape), out_type, std::make_shared(to_stream(s)), {a, b, lhs_indices, rhs_indices}); // Remove the possibly inserted singleton dimensions - if (in_a_ndim == 1 || in_b_ndim == 1) { - out_shape.erase( - out_shape.end() - ((in_a_ndim == 1) ? 2 : 1), - out_shape.end() - ((in_b_ndim == 1) ? 0 : 1)); - out = reshape(out, out_shape, s); + std::vector axes; + if (in_a_ndim == 1) { + axes.push_back(out.ndim() - 2); } - - return out; + if (in_b_ndim == 1) { + axes.push_back(out.ndim() - 1); + } + return axes.empty() ? out : squeeze(out, axes, s); } array diagonal( diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index ab1c1f03b..e139f7385 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1602,6 +1602,55 @@ std::pair, std::vector> Expm1::vmap( return {{expm1(inputs[0], stream())}, axes}; } +std::vector ExpandDims::vjp( + const std::vector&, + const std::vector& cotangents, + const std::vector&, + const std::vector&) { + return {squeeze(cotangents[0], axes_, stream())}; +} + +std::vector ExpandDims::jvp( + const std::vector&, + const std::vector& tangents, + const std::vector&) { + return {expand_dims(tangents[0], axes_, stream())}; +} + +std::pair, std::vector> ExpandDims::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto ax = axes[0]; + auto expand_axes = axes_; + for (auto& s : expand_axes) { + if (s >= axes[0]) { + s++; + } else { + ax++; + } + } + return {{expand_dims(inputs[0], std::move(expand_axes), stream())}, {ax}}; +} + +bool ExpandDims::is_equivalent(const Primitive& other) const { + const ExpandDims& a_other = static_cast(other); + return (axes_ == a_other.axes_); +} + +Shape ExpandDims::output_shape( + const array& input, + const std::vector& axes) { + auto shape = input.shape(); + for (auto ax : axes) { + shape.insert(shape.begin() + ax, 1); + } + return shape; +} + +std::vector ExpandDims::output_shapes(const std::vector& inputs) { + return {ExpandDims::output_shape(inputs[0], axes_)}; +} + bool FFT::is_equivalent(const Primitive& other) const { const FFT& r_other = static_cast(other); return axes_ == r_other.axes_ && inverse_ == r_other.inverse_ && @@ -1846,7 +1895,7 @@ bool Gather::is_equivalent(const Primitive& other) const { std::vector Gather::output_shapes(const std::vector& inputs) { Shape out_shape; if (inputs.size() > 1) { - out_shape = inputs[0].shape(); + out_shape = inputs[1].shape(); } out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end()); return {std::move(out_shape)}; @@ -3847,6 +3896,57 @@ std::pair, std::vector> Subtract::vmap( return {{subtract(a, b, stream())}, {to_ax}}; } +std::vector Squeeze::vjp( + const std::vector&, + const std::vector& cotangents, + const std::vector&, + const std::vector&) { + return {expand_dims(cotangents[0], axes_, stream())}; +} + +std::vector Squeeze::jvp( + const std::vector&, + const std::vector& tangents, + const std::vector&) { + return {squeeze(tangents[0], axes_, stream())}; +} + +std::pair, std::vector> Squeeze::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto ax = axes[0]; + auto squeeze_axes = axes_; + for (auto& s : squeeze_axes) { + if (s >= axes[0]) { + s++; + } else { + ax--; + } + } + return {{squeeze(inputs[0], std::move(squeeze_axes), stream())}, {ax}}; +} + +bool Squeeze::is_equivalent(const Primitive& other) const { + const Squeeze& a_other = static_cast(other); + return (axes_ == a_other.axes_); +} + +Shape Squeeze::output_shape(const array& input, const std::vector& axes) { + Shape shape; + for (int i = 0, j = 0; i < input.ndim(); ++i) { + if (j < axes.size() && i == axes[j]) { + j++; + } else { + shape.push_back(input.shape(i)); + } + } + return shape; +} + +std::vector Squeeze::output_shapes(const std::vector& inputs) { + return {Squeeze::output_shape(inputs[0], axes_)}; +} + std::vector Tan::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index a166f164c..1ca913a37 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -983,6 +983,28 @@ class Expm1 : public UnaryPrimitive { void eval(const std::vector& inputs, array& out); }; +class ExpandDims : public UnaryPrimitive { + public: + explicit ExpandDims(Stream stream, std::vector axes) + : UnaryPrimitive(stream), axes_(std::move(axes)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_PRINT(ExpandDims) + + std::vector output_shapes(const std::vector& inputs) override; + bool is_equivalent(const Primitive& other) const override; + + static Shape output_shape(const array& input, const std::vector& axes); + + private: + void eval(const std::vector& inputs, array& out); + std::vector axes_; +}; + class FFT : public UnaryPrimitive { public: explicit FFT( @@ -1046,9 +1068,11 @@ class Gather : public UnaryPrimitive { public: explicit Gather( Stream stream, - const std::vector& axes, - const std::vector& slice_sizes) - : UnaryPrimitive(stream), axes_(axes), slice_sizes_(slice_sizes) {} + std::vector axes, + std::vector slice_sizes) + : UnaryPrimitive(stream), + axes_(std::move(axes)), + slice_sizes_(std::move(slice_sizes)) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -2057,6 +2081,28 @@ class Subtract : public UnaryPrimitive { void eval(const std::vector& inputs, array& out); }; +class Squeeze : public UnaryPrimitive { + public: + explicit Squeeze(Stream stream, std::vector axes) + : UnaryPrimitive(stream), axes_(std::move(axes)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_PRINT(Squeeze) + + std::vector output_shapes(const std::vector& inputs) override; + bool is_equivalent(const Primitive& other) const override; + + static Shape output_shape(const array& input, const std::vector& axes); + + private: + void eval(const std::vector& inputs, array& out); + std::vector axes_; +}; + class Tan : public UnaryPrimitive { public: explicit Tan(Stream stream) : UnaryPrimitive(stream) {} diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 197441da9..d092f30c2 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -144,23 +144,23 @@ array mlx_gather_nd( int slice_index = 0; for (int i = 0; i < gather_indices.size(); i++) { if (is_slice[i]) { - std::vector index_shape(max_dims + num_slices, 1); + Shape index_shape(max_dims + num_slices, 1); index_shape[max_dims + slice_index] = gather_indices[i].shape(0); - gather_indices[i] = reshape(gather_indices[i], index_shape); + gather_indices[i] = reshape(gather_indices[i], std::move(index_shape)); slice_index++; } else { - std::vector index_shape = gather_indices[i].shape(); + auto index_shape = gather_indices[i].shape(); index_shape.insert(index_shape.end(), num_slices, 1); - gather_indices[i] = reshape(gather_indices[i], index_shape); + gather_indices[i] = reshape(gather_indices[i], std::move(index_shape)); } } } else { // reshape them so that the int/array indices are last for (int i = 0; i < gather_indices.size(); i++) { if (i < num_slices) { - std::vector index_shape(max_dims + num_slices, 1); + Shape index_shape(max_dims + num_slices, 1); index_shape[i] = gather_indices[i].shape(0); - gather_indices[i] = reshape(gather_indices[i], index_shape); + gather_indices[i] = reshape(gather_indices[i], std::move(index_shape)); } } } @@ -172,19 +172,11 @@ array mlx_gather_nd( std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1); src = gather(src, gather_indices, axes, slice_sizes); - // Squeeze the dims - std::vector out_shape; - out_shape.insert( - out_shape.end(), - src.shape().begin(), - src.shape().begin() + max_dims + num_slices); - out_shape.insert( - out_shape.end(), - src.shape().begin() + max_dims + num_slices + indices.size(), - src.shape().end()); - src = reshape(src, out_shape); - - return src; + // Squeeze the array index dims + for (auto& ax : axes) { + ax += max_dims + num_slices; + } + return squeeze(src, axes); } auto mlx_expand_ellipsis( diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index feb8e6da6..c0d5e3647 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -392,27 +392,6 @@ class TestCompile(mlx_tests.MLXTestCase): out = fun(x, y=y, z=z) self.assertEqual(out.item(), 6) - def test_shapeless_compile(self): - y = 1 - - @partial(mx.compile, shapeless=True) - def fun(x): - return x + y - - x = mx.array([1, 2]) - self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3]))) - - # The function is not recompiled, so the change - # to y should not be reflected in the output - y = 2 - x = mx.array([1, 2, 3]) - self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4]))) - - # Type change recompiles - x = mx.array([1.0, 2.0, 3.0]) - self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0]))) - fun(x, y=y, z=z) - def test_shapeless_compile(self): y = 1 @@ -477,6 +456,12 @@ class TestCompile(mlx_tests.MLXTestCase): mx.eval(cfun(x1)) self.assertTrue(mx.array_equal(fun(x2), cfun(x2))) + def fun(x): + return x * x.sum(-1, keepdims=False) + + cfun = mx.compile(fun, shapeless=True) + self.assertTrue(mx.array_equal(fun(x2), cfun(x2))) + def test_compile_with_constant(self): # Test float @partial(mx.compile) @@ -809,6 +794,13 @@ class TestCompile(mlx_tests.MLXTestCase): out = fun(*inputs) self.assertTrue(mx.allclose(out, mx.full((2, 2), 20))) + def test_shapeless_compile_matmul(self): + a = mx.array([0.0, 1.0, 2.0]) + b = mx.array([0.0, 1.0, 2.0]) + + fun = mx.compile(lambda a, b: a @ b, shapeless=True) + self.assertTrue(mx.allclose(fun(a, b), a @ b)) + if __name__ == "__main__": unittest.main() diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 545f5e24c..a3638cfec 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1835,6 +1835,9 @@ TEST_CASE("test broadcast") { } TEST_CASE("test gather") { + // Empty input, non-empty indices/slice + CHECK_THROWS(gather(array({}), array({1}), 0, {1})); + // More indices than dimensions CHECK_THROWS(gather(array(0), array({1}), 0, {1}));