mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Use expand_dims / unflatten / etc in more places (#1696)
* use expand_dims / unflatten in a couple more places * few more * few more * fix
This commit is contained in:
parent
9111999af3
commit
50f3535693
13
mlx/fast.cpp
13
mlx/fast.cpp
@ -620,10 +620,15 @@ array scaled_dot_product_attention(
|
||||
}
|
||||
auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
|
||||
if (inputs.size() > 3) {
|
||||
auto mask_shape = inputs[0].shape();
|
||||
mask_shape.back() = inputs[1].shape(-2);
|
||||
auto mask = reshape(
|
||||
broadcast_to(inputs[3], std::move(mask_shape), s), scores.shape(), s);
|
||||
// Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]
|
||||
auto mask = inputs[3];
|
||||
if (n_repeats > 1 && mask.ndim() >= 3) {
|
||||
if (mask.shape(-3) == 1) {
|
||||
mask = expand_dims(mask, -3, s);
|
||||
} else {
|
||||
mask = unflatten(mask, -3, {n_kv_heads, n_repeats}, s);
|
||||
}
|
||||
}
|
||||
scores = add(scores, mask, s);
|
||||
}
|
||||
scores = softmax(scores, std::vector<int>{-1}, true, s);
|
||||
|
@ -542,6 +542,9 @@ array squeeze(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
std::set<int> unique_axes;
|
||||
for (auto ax : axes) {
|
||||
unique_axes.insert(ax < 0 ? ax + a.ndim() : ax);
|
||||
@ -598,6 +601,9 @@ array expand_dims(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
{ // Check for repeats
|
||||
std::set<int> unique_axes(axes.begin(), axes.end());
|
||||
if (unique_axes.size() != axes.size()) {
|
||||
|
@ -176,7 +176,7 @@ mx::array mlx_gather_nd(
|
||||
for (auto& ax : axes) {
|
||||
ax += max_dims + num_slices;
|
||||
}
|
||||
return squeeze(src, axes);
|
||||
return mx::squeeze(src, axes);
|
||||
}
|
||||
|
||||
auto mlx_expand_ellipsis(
|
||||
@ -438,9 +438,7 @@ mx::array mlx_get_item(const mx::array& src, const nb::object& obj) {
|
||||
} else if (nb::isinstance<nb::ellipsis>(obj)) {
|
||||
return src;
|
||||
} else if (obj.is_none()) {
|
||||
std::vector<int> s(1, 1);
|
||||
s.insert(s.end(), src.shape().begin(), src.shape().end());
|
||||
return reshape(src, s);
|
||||
return expand_dims(src, 0);
|
||||
} else if (nb::isinstance<nb::list>(obj)) {
|
||||
return mlx_get_item_array(
|
||||
src, array_from_list(nb::cast<nb::list>(obj), {}));
|
||||
@ -474,6 +472,15 @@ mlx_scatter_args_int(
|
||||
{0}};
|
||||
}
|
||||
|
||||
mx::array squeeze_leading_singletons(const mx::array& in) {
|
||||
int s = 0;
|
||||
for (; s < in.ndim() && in.shape(s) == 1; s++)
|
||||
;
|
||||
auto squeeze_axes = std::vector<int>(s);
|
||||
std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
|
||||
return mx::squeeze(in, squeeze_axes);
|
||||
}
|
||||
|
||||
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
|
||||
mlx_scatter_args_array(
|
||||
const mx::array& src,
|
||||
@ -484,16 +491,10 @@ mlx_scatter_args_array(
|
||||
"too many indices for array: array is 0-dimensional");
|
||||
}
|
||||
|
||||
// Remove any leading singleton dimensions from the update
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++)
|
||||
;
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
auto up = reshape(update, up_shape);
|
||||
auto up = squeeze_leading_singletons(update);
|
||||
|
||||
// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
|
||||
up_shape = indices.shape();
|
||||
auto up_shape = indices.shape();
|
||||
up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
|
||||
up = broadcast_to(up, up_shape);
|
||||
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
|
||||
@ -516,12 +517,8 @@ mlx_scatter_args_slice(
|
||||
// If none slice is requested broadcast the update
|
||||
// to the src size and return it.
|
||||
if (is_none_slice(in_slice)) {
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++)
|
||||
;
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
return {{}, broadcast_to(reshape(update, up_shape), src.shape()), {}};
|
||||
return {
|
||||
{}, broadcast_to(squeeze_leading_singletons(update), src.shape()), {}};
|
||||
}
|
||||
|
||||
int start = 0;
|
||||
@ -534,12 +531,7 @@ mlx_scatter_args_slice(
|
||||
// If simple stride
|
||||
if (stride == 1) {
|
||||
// Squeeze out singleton dims from the start of update
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++)
|
||||
;
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
auto up = reshape(update, up_shape);
|
||||
auto up = squeeze_leading_singletons(update);
|
||||
|
||||
// Build array to mark start of slice
|
||||
auto idx = mx::array({start}, {1}, mx::uint32);
|
||||
@ -548,7 +540,7 @@ mlx_scatter_args_slice(
|
||||
int slice_size = (end - start);
|
||||
|
||||
// Broadcast update to slice size
|
||||
std::vector<int> up_shape_broadcast = {1, slice_size};
|
||||
mx::Shape up_shape_broadcast = {1, slice_size};
|
||||
up_shape_broadcast.insert(
|
||||
up_shape_broadcast.end(), src.shape().begin() + 1, src.shape().end());
|
||||
|
||||
@ -585,13 +577,7 @@ mlx_scatter_args_nd(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Remove leading singletons dimensions from the update
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++) {
|
||||
};
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
auto up = reshape(update, up_shape);
|
||||
auto up = squeeze_leading_singletons(update);
|
||||
|
||||
// If no non-None indices return the broadcasted update
|
||||
if (non_none_indices == 0) {
|
||||
@ -703,7 +689,7 @@ mlx_scatter_args_nd(
|
||||
} else if (nb::isinstance<mx::array>(pyidx)) {
|
||||
ax++;
|
||||
auto idx = nb::cast<mx::array>(pyidx);
|
||||
std::vector<int> idx_shape(idx_ndim, 1);
|
||||
mx::Shape idx_shape(idx_ndim, 1);
|
||||
|
||||
// Place the arrays in the correct dimension
|
||||
int st = (!arrays_first) * slice_num + max_dim - idx.ndim();
|
||||
@ -801,17 +787,18 @@ auto mlx_slice_update(
|
||||
|
||||
// Remove extra leading singletons dimensions from the update
|
||||
int s = 0;
|
||||
for (; s < upd.ndim() && upd.shape(s) == 1 && (upd.ndim() - s) > src.ndim();
|
||||
for (; s < static_cast<int>(upd.ndim()) - 1 && upd.shape(s) == 1 &&
|
||||
(upd.ndim() - s) > src.ndim();
|
||||
s++) {
|
||||
};
|
||||
auto up_shape = std::vector<int>(upd.shape().begin() + s, upd.shape().end());
|
||||
up_shape = up_shape.empty() ? std::vector{1} : up_shape;
|
||||
auto up = reshape(upd, up_shape);
|
||||
auto squeeze_axes = std::vector<int>(s);
|
||||
std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
|
||||
auto up = mx::squeeze(upd, squeeze_axes);
|
||||
|
||||
// Build slice update params
|
||||
std::vector<int> starts(src.ndim(), 0);
|
||||
std::vector<int> stops = src.shape();
|
||||
std::vector<int> strides(src.ndim(), 1);
|
||||
mx::Shape starts(src.ndim(), 0);
|
||||
mx::Shape stops = src.shape();
|
||||
mx::Shape strides(src.ndim(), 1);
|
||||
|
||||
// If it's just a simple slice, just do a slice update and return
|
||||
if (nb::isinstance<nb::slice>(obj)) {
|
||||
@ -847,7 +834,7 @@ auto mlx_slice_update(
|
||||
}
|
||||
|
||||
// Process entries
|
||||
std::vector<int> up_reshape(src.ndim());
|
||||
mx::Shape up_reshape(src.ndim());
|
||||
int ax = src.ndim() - 1;
|
||||
int up_ax = up.ndim() - 1;
|
||||
for (; ax >= non_none_indices; ax--) {
|
||||
|
Loading…
Reference in New Issue
Block a user