Use expand_dims / unflatten / etc in more places (#1696)

* use expand_dims / unflatten in a couple more places

* few more

* few more

* fix
This commit is contained in:
Awni Hannun 2024-12-12 17:00:44 -08:00 committed by GitHub
parent 9111999af3
commit 50f3535693
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 43 additions and 45 deletions

View File

@ -620,10 +620,15 @@ array scaled_dot_product_attention(
}
auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
if (inputs.size() > 3) {
auto mask_shape = inputs[0].shape();
mask_shape.back() = inputs[1].shape(-2);
auto mask = reshape(
broadcast_to(inputs[3], std::move(mask_shape), s), scores.shape(), s);
// Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]
auto mask = inputs[3];
if (n_repeats > 1 && mask.ndim() >= 3) {
if (mask.shape(-3) == 1) {
mask = expand_dims(mask, -3, s);
} else {
mask = unflatten(mask, -3, {n_kv_heads, n_repeats}, s);
}
}
scores = add(scores, mask, s);
}
scores = softmax(scores, std::vector<int>{-1}, true, s);

View File

@ -542,6 +542,9 @@ array squeeze(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
if (axes.empty()) {
return a;
}
std::set<int> unique_axes;
for (auto ax : axes) {
unique_axes.insert(ax < 0 ? ax + a.ndim() : ax);
@ -598,6 +601,9 @@ array expand_dims(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
if (axes.empty()) {
return a;
}
{ // Check for repeats
std::set<int> unique_axes(axes.begin(), axes.end());
if (unique_axes.size() != axes.size()) {

View File

@ -176,7 +176,7 @@ mx::array mlx_gather_nd(
for (auto& ax : axes) {
ax += max_dims + num_slices;
}
return squeeze(src, axes);
return mx::squeeze(src, axes);
}
auto mlx_expand_ellipsis(
@ -438,9 +438,7 @@ mx::array mlx_get_item(const mx::array& src, const nb::object& obj) {
} else if (nb::isinstance<nb::ellipsis>(obj)) {
return src;
} else if (obj.is_none()) {
std::vector<int> s(1, 1);
s.insert(s.end(), src.shape().begin(), src.shape().end());
return reshape(src, s);
return expand_dims(src, 0);
} else if (nb::isinstance<nb::list>(obj)) {
return mlx_get_item_array(
src, array_from_list(nb::cast<nb::list>(obj), {}));
@ -474,6 +472,15 @@ mlx_scatter_args_int(
{0}};
}
mx::array squeeze_leading_singletons(const mx::array& in) {
int s = 0;
for (; s < in.ndim() && in.shape(s) == 1; s++)
;
auto squeeze_axes = std::vector<int>(s);
std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
return mx::squeeze(in, squeeze_axes);
}
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
mlx_scatter_args_array(
const mx::array& src,
@ -484,16 +491,10 @@ mlx_scatter_args_array(
"too many indices for array: array is 0-dimensional");
}
// Remove any leading singleton dimensions from the update
int s = 0;
for (; s < update.ndim() && update.shape(s) == 1; s++)
;
auto up_shape =
std::vector<int>(update.shape().begin() + s, update.shape().end());
auto up = reshape(update, up_shape);
auto up = squeeze_leading_singletons(update);
// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
up_shape = indices.shape();
auto up_shape = indices.shape();
up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
up = broadcast_to(up, up_shape);
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
@ -516,12 +517,8 @@ mlx_scatter_args_slice(
// If none slice is requested broadcast the update
// to the src size and return it.
if (is_none_slice(in_slice)) {
int s = 0;
for (; s < update.ndim() && update.shape(s) == 1; s++)
;
auto up_shape =
std::vector<int>(update.shape().begin() + s, update.shape().end());
return {{}, broadcast_to(reshape(update, up_shape), src.shape()), {}};
return {
{}, broadcast_to(squeeze_leading_singletons(update), src.shape()), {}};
}
int start = 0;
@ -534,12 +531,7 @@ mlx_scatter_args_slice(
// If simple stride
if (stride == 1) {
// Squeeze out singleton dims from the start of update
int s = 0;
for (; s < update.ndim() && update.shape(s) == 1; s++)
;
auto up_shape =
std::vector<int>(update.shape().begin() + s, update.shape().end());
auto up = reshape(update, up_shape);
auto up = squeeze_leading_singletons(update);
// Build array to mark start of slice
auto idx = mx::array({start}, {1}, mx::uint32);
@ -548,7 +540,7 @@ mlx_scatter_args_slice(
int slice_size = (end - start);
// Broadcast update to slice size
std::vector<int> up_shape_broadcast = {1, slice_size};
mx::Shape up_shape_broadcast = {1, slice_size};
up_shape_broadcast.insert(
up_shape_broadcast.end(), src.shape().begin() + 1, src.shape().end());
@ -585,13 +577,7 @@ mlx_scatter_args_nd(
throw std::invalid_argument(msg.str());
}
// Remove leading singletons dimensions from the update
int s = 0;
for (; s < update.ndim() && update.shape(s) == 1; s++) {
};
auto up_shape =
std::vector<int>(update.shape().begin() + s, update.shape().end());
auto up = reshape(update, up_shape);
auto up = squeeze_leading_singletons(update);
// If no non-None indices return the broadcasted update
if (non_none_indices == 0) {
@ -703,7 +689,7 @@ mlx_scatter_args_nd(
} else if (nb::isinstance<mx::array>(pyidx)) {
ax++;
auto idx = nb::cast<mx::array>(pyidx);
std::vector<int> idx_shape(idx_ndim, 1);
mx::Shape idx_shape(idx_ndim, 1);
// Place the arrays in the correct dimension
int st = (!arrays_first) * slice_num + max_dim - idx.ndim();
@ -801,17 +787,18 @@ auto mlx_slice_update(
// Remove extra leading singletons dimensions from the update
int s = 0;
for (; s < upd.ndim() && upd.shape(s) == 1 && (upd.ndim() - s) > src.ndim();
for (; s < static_cast<int>(upd.ndim()) - 1 && upd.shape(s) == 1 &&
(upd.ndim() - s) > src.ndim();
s++) {
};
auto up_shape = std::vector<int>(upd.shape().begin() + s, upd.shape().end());
up_shape = up_shape.empty() ? std::vector{1} : up_shape;
auto up = reshape(upd, up_shape);
auto squeeze_axes = std::vector<int>(s);
std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
auto up = mx::squeeze(upd, squeeze_axes);
// Build slice update params
std::vector<int> starts(src.ndim(), 0);
std::vector<int> stops = src.shape();
std::vector<int> strides(src.ndim(), 1);
mx::Shape starts(src.ndim(), 0);
mx::Shape stops = src.shape();
mx::Shape strides(src.ndim(), 1);
// If it's just a simple slice, just do a slice update and return
if (nb::isinstance<nb::slice>(obj)) {
@ -847,7 +834,7 @@ auto mlx_slice_update(
}
// Process entries
std::vector<int> up_reshape(src.ndim());
mx::Shape up_reshape(src.ndim());
int ax = src.ndim() - 1;
int up_ax = up.ndim() - 1;
for (; ax >= non_none_indices; ax--) {