diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 58a0c32a6..58800fb70 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -620,10 +620,15 @@ array scaled_dot_product_attention( } auto scores = matmul(q, swapaxes(k, -1, -2, s), s); if (inputs.size() > 3) { - auto mask_shape = inputs[0].shape(); - mask_shape.back() = inputs[1].shape(-2); - auto mask = reshape( - broadcast_to(inputs[3], std::move(mask_shape), s), scores.shape(), s); + // Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv] + auto mask = inputs[3]; + if (n_repeats > 1 && mask.ndim() >= 3) { + if (mask.shape(-3) == 1) { + mask = expand_dims(mask, -3, s); + } else { + mask = unflatten(mask, -3, {n_kv_heads, n_repeats}, s); + } + } scores = add(scores, mask, s); } scores = softmax(scores, std::vector{-1}, true, s); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 8433b4129..a0a259580 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -542,6 +542,9 @@ array squeeze( const array& a, const std::vector& axes, StreamOrDevice s /* = {} */) { + if (axes.empty()) { + return a; + } std::set unique_axes; for (auto ax : axes) { unique_axes.insert(ax < 0 ? ax + a.ndim() : ax); @@ -598,6 +601,9 @@ array expand_dims( const array& a, const std::vector& axes, StreamOrDevice s /* = {} */) { + if (axes.empty()) { + return a; + } { // Check for repeats std::set unique_axes(axes.begin(), axes.end()); if (unique_axes.size() != axes.size()) { diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 6278f5b99..3c042ce33 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -176,7 +176,7 @@ mx::array mlx_gather_nd( for (auto& ax : axes) { ax += max_dims + num_slices; } - return squeeze(src, axes); + return mx::squeeze(src, axes); } auto mlx_expand_ellipsis( @@ -438,9 +438,7 @@ mx::array mlx_get_item(const mx::array& src, const nb::object& obj) { } else if (nb::isinstance(obj)) { return src; } else if (obj.is_none()) { - std::vector s(1, 1); - s.insert(s.end(), src.shape().begin(), src.shape().end()); - return reshape(src, s); + return expand_dims(src, 0); } else if (nb::isinstance(obj)) { return mlx_get_item_array( src, array_from_list(nb::cast(obj), {})); @@ -474,6 +472,15 @@ mlx_scatter_args_int( {0}}; } +mx::array squeeze_leading_singletons(const mx::array& in) { + int s = 0; + for (; s < in.ndim() && in.shape(s) == 1; s++) + ; + auto squeeze_axes = std::vector(s); + std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0); + return mx::squeeze(in, squeeze_axes); +} + std::tuple, mx::array, std::vector> mlx_scatter_args_array( const mx::array& src, @@ -484,16 +491,10 @@ mlx_scatter_args_array( "too many indices for array: array is 0-dimensional"); } - // Remove any leading singleton dimensions from the update - int s = 0; - for (; s < update.ndim() && update.shape(s) == 1; s++) - ; - auto up_shape = - std::vector(update.shape().begin() + s, update.shape().end()); - auto up = reshape(update, up_shape); + auto up = squeeze_leading_singletons(update); // The update shape must broadcast with indices.shape + [1] + src.shape[1:] - up_shape = indices.shape(); + auto up_shape = indices.shape(); up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end()); up = broadcast_to(up, up_shape); up_shape.insert(up_shape.begin() + indices.ndim(), 1); @@ -516,12 +517,8 @@ mlx_scatter_args_slice( // If none slice is requested broadcast the update // to the src size and return it. if (is_none_slice(in_slice)) { - int s = 0; - for (; s < update.ndim() && update.shape(s) == 1; s++) - ; - auto up_shape = - std::vector(update.shape().begin() + s, update.shape().end()); - return {{}, broadcast_to(reshape(update, up_shape), src.shape()), {}}; + return { + {}, broadcast_to(squeeze_leading_singletons(update), src.shape()), {}}; } int start = 0; @@ -534,12 +531,7 @@ mlx_scatter_args_slice( // If simple stride if (stride == 1) { // Squeeze out singleton dims from the start of update - int s = 0; - for (; s < update.ndim() && update.shape(s) == 1; s++) - ; - auto up_shape = - std::vector(update.shape().begin() + s, update.shape().end()); - auto up = reshape(update, up_shape); + auto up = squeeze_leading_singletons(update); // Build array to mark start of slice auto idx = mx::array({start}, {1}, mx::uint32); @@ -548,7 +540,7 @@ mlx_scatter_args_slice( int slice_size = (end - start); // Broadcast update to slice size - std::vector up_shape_broadcast = {1, slice_size}; + mx::Shape up_shape_broadcast = {1, slice_size}; up_shape_broadcast.insert( up_shape_broadcast.end(), src.shape().begin() + 1, src.shape().end()); @@ -585,13 +577,7 @@ mlx_scatter_args_nd( throw std::invalid_argument(msg.str()); } - // Remove leading singletons dimensions from the update - int s = 0; - for (; s < update.ndim() && update.shape(s) == 1; s++) { - }; - auto up_shape = - std::vector(update.shape().begin() + s, update.shape().end()); - auto up = reshape(update, up_shape); + auto up = squeeze_leading_singletons(update); // If no non-None indices return the broadcasted update if (non_none_indices == 0) { @@ -703,7 +689,7 @@ mlx_scatter_args_nd( } else if (nb::isinstance(pyidx)) { ax++; auto idx = nb::cast(pyidx); - std::vector idx_shape(idx_ndim, 1); + mx::Shape idx_shape(idx_ndim, 1); // Place the arrays in the correct dimension int st = (!arrays_first) * slice_num + max_dim - idx.ndim(); @@ -801,17 +787,18 @@ auto mlx_slice_update( // Remove extra leading singletons dimensions from the update int s = 0; - for (; s < upd.ndim() && upd.shape(s) == 1 && (upd.ndim() - s) > src.ndim(); + for (; s < static_cast(upd.ndim()) - 1 && upd.shape(s) == 1 && + (upd.ndim() - s) > src.ndim(); s++) { }; - auto up_shape = std::vector(upd.shape().begin() + s, upd.shape().end()); - up_shape = up_shape.empty() ? std::vector{1} : up_shape; - auto up = reshape(upd, up_shape); + auto squeeze_axes = std::vector(s); + std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0); + auto up = mx::squeeze(upd, squeeze_axes); // Build slice update params - std::vector starts(src.ndim(), 0); - std::vector stops = src.shape(); - std::vector strides(src.ndim(), 1); + mx::Shape starts(src.ndim(), 0); + mx::Shape stops = src.shape(); + mx::Shape strides(src.ndim(), 1); // If it's just a simple slice, just do a slice update and return if (nb::isinstance(obj)) { @@ -847,7 +834,7 @@ auto mlx_slice_update( } // Process entries - std::vector up_reshape(src.ndim()); + mx::Shape up_reshape(src.ndim()); int ax = src.ndim() - 1; int up_ax = up.ndim() - 1; for (; ax >= non_none_indices; ax--) {