mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 23:24:41 +08:00
Use expand_dims / unflatten / etc in more places (#1696)
* use expand_dims / unflatten in a couple more places * few more * few more * fix
This commit is contained in:
@@ -176,7 +176,7 @@ mx::array mlx_gather_nd(
|
||||
for (auto& ax : axes) {
|
||||
ax += max_dims + num_slices;
|
||||
}
|
||||
return squeeze(src, axes);
|
||||
return mx::squeeze(src, axes);
|
||||
}
|
||||
|
||||
auto mlx_expand_ellipsis(
|
||||
@@ -438,9 +438,7 @@ mx::array mlx_get_item(const mx::array& src, const nb::object& obj) {
|
||||
} else if (nb::isinstance<nb::ellipsis>(obj)) {
|
||||
return src;
|
||||
} else if (obj.is_none()) {
|
||||
std::vector<int> s(1, 1);
|
||||
s.insert(s.end(), src.shape().begin(), src.shape().end());
|
||||
return reshape(src, s);
|
||||
return expand_dims(src, 0);
|
||||
} else if (nb::isinstance<nb::list>(obj)) {
|
||||
return mlx_get_item_array(
|
||||
src, array_from_list(nb::cast<nb::list>(obj), {}));
|
||||
@@ -474,6 +472,15 @@ mlx_scatter_args_int(
|
||||
{0}};
|
||||
}
|
||||
|
||||
mx::array squeeze_leading_singletons(const mx::array& in) {
|
||||
int s = 0;
|
||||
for (; s < in.ndim() && in.shape(s) == 1; s++)
|
||||
;
|
||||
auto squeeze_axes = std::vector<int>(s);
|
||||
std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
|
||||
return mx::squeeze(in, squeeze_axes);
|
||||
}
|
||||
|
||||
std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
|
||||
mlx_scatter_args_array(
|
||||
const mx::array& src,
|
||||
@@ -484,16 +491,10 @@ mlx_scatter_args_array(
|
||||
"too many indices for array: array is 0-dimensional");
|
||||
}
|
||||
|
||||
// Remove any leading singleton dimensions from the update
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++)
|
||||
;
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
auto up = reshape(update, up_shape);
|
||||
auto up = squeeze_leading_singletons(update);
|
||||
|
||||
// The update shape must broadcast with indices.shape + [1] + src.shape[1:]
|
||||
up_shape = indices.shape();
|
||||
auto up_shape = indices.shape();
|
||||
up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
|
||||
up = broadcast_to(up, up_shape);
|
||||
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
|
||||
@@ -516,12 +517,8 @@ mlx_scatter_args_slice(
|
||||
// If none slice is requested broadcast the update
|
||||
// to the src size and return it.
|
||||
if (is_none_slice(in_slice)) {
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++)
|
||||
;
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
return {{}, broadcast_to(reshape(update, up_shape), src.shape()), {}};
|
||||
return {
|
||||
{}, broadcast_to(squeeze_leading_singletons(update), src.shape()), {}};
|
||||
}
|
||||
|
||||
int start = 0;
|
||||
@@ -534,12 +531,7 @@ mlx_scatter_args_slice(
|
||||
// If simple stride
|
||||
if (stride == 1) {
|
||||
// Squeeze out singleton dims from the start of update
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++)
|
||||
;
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
auto up = reshape(update, up_shape);
|
||||
auto up = squeeze_leading_singletons(update);
|
||||
|
||||
// Build array to mark start of slice
|
||||
auto idx = mx::array({start}, {1}, mx::uint32);
|
||||
@@ -548,7 +540,7 @@ mlx_scatter_args_slice(
|
||||
int slice_size = (end - start);
|
||||
|
||||
// Broadcast update to slice size
|
||||
std::vector<int> up_shape_broadcast = {1, slice_size};
|
||||
mx::Shape up_shape_broadcast = {1, slice_size};
|
||||
up_shape_broadcast.insert(
|
||||
up_shape_broadcast.end(), src.shape().begin() + 1, src.shape().end());
|
||||
|
||||
@@ -585,13 +577,7 @@ mlx_scatter_args_nd(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Remove leading singletons dimensions from the update
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++) {
|
||||
};
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
auto up = reshape(update, up_shape);
|
||||
auto up = squeeze_leading_singletons(update);
|
||||
|
||||
// If no non-None indices return the broadcasted update
|
||||
if (non_none_indices == 0) {
|
||||
@@ -703,7 +689,7 @@ mlx_scatter_args_nd(
|
||||
} else if (nb::isinstance<mx::array>(pyidx)) {
|
||||
ax++;
|
||||
auto idx = nb::cast<mx::array>(pyidx);
|
||||
std::vector<int> idx_shape(idx_ndim, 1);
|
||||
mx::Shape idx_shape(idx_ndim, 1);
|
||||
|
||||
// Place the arrays in the correct dimension
|
||||
int st = (!arrays_first) * slice_num + max_dim - idx.ndim();
|
||||
@@ -801,17 +787,18 @@ auto mlx_slice_update(
|
||||
|
||||
// Remove extra leading singletons dimensions from the update
|
||||
int s = 0;
|
||||
for (; s < upd.ndim() && upd.shape(s) == 1 && (upd.ndim() - s) > src.ndim();
|
||||
for (; s < static_cast<int>(upd.ndim()) - 1 && upd.shape(s) == 1 &&
|
||||
(upd.ndim() - s) > src.ndim();
|
||||
s++) {
|
||||
};
|
||||
auto up_shape = std::vector<int>(upd.shape().begin() + s, upd.shape().end());
|
||||
up_shape = up_shape.empty() ? std::vector{1} : up_shape;
|
||||
auto up = reshape(upd, up_shape);
|
||||
auto squeeze_axes = std::vector<int>(s);
|
||||
std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
|
||||
auto up = mx::squeeze(upd, squeeze_axes);
|
||||
|
||||
// Build slice update params
|
||||
std::vector<int> starts(src.ndim(), 0);
|
||||
std::vector<int> stops = src.shape();
|
||||
std::vector<int> strides(src.ndim(), 1);
|
||||
mx::Shape starts(src.ndim(), 0);
|
||||
mx::Shape stops = src.shape();
|
||||
mx::Shape strides(src.ndim(), 1);
|
||||
|
||||
// If it's just a simple slice, just do a slice update and return
|
||||
if (nb::isinstance<nb::slice>(obj)) {
|
||||
@@ -847,7 +834,7 @@ auto mlx_slice_update(
|
||||
}
|
||||
|
||||
// Process entries
|
||||
std::vector<int> up_reshape(src.ndim());
|
||||
mx::Shape up_reshape(src.ndim());
|
||||
int ax = src.ndim() - 1;
|
||||
int up_ax = up.ndim() - 1;
|
||||
for (; ax >= non_none_indices; ax--) {
|
||||
|
Reference in New Issue
Block a user