Use expand_dims / unflatten / etc in more places (#1696)

* use expand_dims / unflatten in a couple more places

* few more

* few more

* fix
This commit is contained in:
Awni Hannun
2024-12-12 17:00:44 -08:00
committed by GitHub
parent 9111999af3
commit 50f3535693
3 changed files with 43 additions and 45 deletions

View File

@@ -542,6 +542,9 @@ array squeeze(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
if (axes.empty()) {
return a;
}
std::set<int> unique_axes;
for (auto ax : axes) {
unique_axes.insert(ax < 0 ? ax + a.ndim() : ax);
@@ -598,6 +601,9 @@ array expand_dims(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
if (axes.empty()) {
return a;
}
{ // Check for repeats
std::set<int> unique_axes(axes.begin(), axes.end());
if (unique_axes.size() != axes.size()) {