mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use expand_dims / unflatten / etc in more places (#1696)
* use expand_dims / unflatten in a couple more places * few more * few more * fix
This commit is contained in:
@@ -542,6 +542,9 @@ array squeeze(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
std::set<int> unique_axes;
|
||||
for (auto ax : axes) {
|
||||
unique_axes.insert(ax < 0 ? ax + a.ndim() : ax);
|
||||
@@ -598,6 +601,9 @@ array expand_dims(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
{ // Check for repeats
|
||||
std::set<int> unique_axes(axes.begin(), axes.end());
|
||||
if (unique_axes.size() != axes.size()) {
|
||||
|
||||
Reference in New Issue
Block a user