Use expand_dims / unflatten / etc in more places (#1696)

* use expand_dims / unflatten in a couple more places

* few more

* few more

* fix
This commit is contained in:
Awni Hannun
2024-12-12 17:00:44 -08:00
committed by GitHub
parent 9111999af3
commit 50f3535693
3 changed files with 43 additions and 45 deletions

View File

@@ -176,7 +176,7 @@ mx::array mlx_gather_nd(
for (auto& ax : axes) {
ax += max_dims + num_slices;
}
return squeeze(src, axes);
return mx::squeeze(src, axes);
}
auto mlx_expand_ellipsis(
@@ -438,9 +438,7 @@ mx::array mlx_get_item(const mx::array& src, const nb::object& obj) {
} else if (nb::isinstance<nb::ellipsis>(obj)) {
return src;
} else if (obj.is_none()) {
std::vector<int> s(1, 1);
s.insert(s.end(), src.shape().begin(), src.shape().end());
return reshape(src, s);
return expand_dims(src, 0);
} else if (nb::isinstance<nb::list>(obj)) {
return mlx_get_item_array(
src, array_from_list(nb::cast<nb::list>(obj), {}));
@@ -474,6 +472,15 @@ mlx_scatter_args_int(
{0}};
}
mx::array squeeze_leading_singletons(const mx::array& in) {
int s = 0;
for (; s < in.ndim() && in.shape(s) == 1; s++)
;
auto squeeze_axes = std::vector<int>(s);
std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
return mx::squeeze(in, squeeze_axes);
}
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
mlx_scatter_args_array(
const mx::array& src,
@@ -484,16 +491,10 @@ mlx_scatter_args_array(
"too many indices for array: array is 0-dimensional");
}
// Remove any leading singleton dimensions from the update
int s = 0;
for (; s < update.ndim() && update.shape(s) == 1; s++)
;
auto up_shape =
std::vector<int>(update.shape().begin() + s, update.shape().end());
auto up = reshape(update, up_shape);
auto up = squeeze_leading_singletons(update);
// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
up_shape = indices.shape();
auto up_shape = indices.shape();
up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
up = broadcast_to(up, up_shape);
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
@@ -516,12 +517,8 @@ mlx_scatter_args_slice(
// If none slice is requested broadcast the update
// to the src size and return it.
if (is_none_slice(in_slice)) {
int s = 0;
for (; s < update.ndim() && update.shape(s) == 1; s++)
;
auto up_shape =
std::vector<int>(update.shape().begin() + s, update.shape().end());
return {{}, broadcast_to(reshape(update, up_shape), src.shape()), {}};
return {
{}, broadcast_to(squeeze_leading_singletons(update), src.shape()), {}};
}
int start = 0;
@@ -534,12 +531,7 @@ mlx_scatter_args_slice(
// If simple stride
if (stride == 1) {
// Squeeze out singleton dims from the start of update
int s = 0;
for (; s < update.ndim() && update.shape(s) == 1; s++)
;
auto up_shape =
std::vector<int>(update.shape().begin() + s, update.shape().end());
auto up = reshape(update, up_shape);
auto up = squeeze_leading_singletons(update);
// Build array to mark start of slice
auto idx = mx::array({start}, {1}, mx::uint32);
@@ -548,7 +540,7 @@ mlx_scatter_args_slice(
int slice_size = (end - start);
// Broadcast update to slice size
std::vector<int> up_shape_broadcast = {1, slice_size};
mx::Shape up_shape_broadcast = {1, slice_size};
up_shape_broadcast.insert(
up_shape_broadcast.end(), src.shape().begin() + 1, src.shape().end());
@@ -585,13 +577,7 @@ mlx_scatter_args_nd(
throw std::invalid_argument(msg.str());
}
// Remove leading singletons dimensions from the update
int s = 0;
for (; s < update.ndim() && update.shape(s) == 1; s++) {
};
auto up_shape =
std::vector<int>(update.shape().begin() + s, update.shape().end());
auto up = reshape(update, up_shape);
auto up = squeeze_leading_singletons(update);
// If no non-None indices return the broadcasted update
if (non_none_indices == 0) {
@@ -703,7 +689,7 @@ mlx_scatter_args_nd(
} else if (nb::isinstance<mx::array>(pyidx)) {
ax++;
auto idx = nb::cast<mx::array>(pyidx);
std::vector<int> idx_shape(idx_ndim, 1);
mx::Shape idx_shape(idx_ndim, 1);
// Place the arrays in the correct dimension
int st = (!arrays_first) * slice_num + max_dim - idx.ndim();
@@ -801,17 +787,18 @@ auto mlx_slice_update(
// Remove extra leading singletons dimensions from the update
int s = 0;
for (; s < upd.ndim() && upd.shape(s) == 1 && (upd.ndim() - s) > src.ndim();
for (; s < static_cast<int>(upd.ndim()) - 1 && upd.shape(s) == 1 &&
(upd.ndim() - s) > src.ndim();
s++) {
};
auto up_shape = std::vector<int>(upd.shape().begin() + s, upd.shape().end());
up_shape = up_shape.empty() ? std::vector{1} : up_shape;
auto up = reshape(upd, up_shape);
auto squeeze_axes = std::vector<int>(s);
std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
auto up = mx::squeeze(upd, squeeze_axes);
// Build slice update params
std::vector<int> starts(src.ndim(), 0);
std::vector<int> stops = src.shape();
std::vector<int> strides(src.ndim(), 1);
mx::Shape starts(src.ndim(), 0);
mx::Shape stops = src.shape();
mx::Shape strides(src.ndim(), 1);
// If it's just a simple slice, just do a slice update and return
if (nb::isinstance<nb::slice>(obj)) {
@@ -847,7 +834,7 @@ auto mlx_slice_update(
}
// Process entries
std::vector<int> up_reshape(src.ndim());
mx::Shape up_reshape(src.ndim());
int ax = src.ndim() - 1;
int up_ax = up.ndim() - 1;
for (; ax >= non_none_indices; ax--) {