Added support for atleast_1d, atleast_2d, atleast_3d (#694)

This commit is contained in:
Hinrik Snær Guðmundsson
2024-02-19 12:40:52 -05:00
committed by GitHub
parent e1bdf6a8d9
commit f883fcede0
7 changed files with 241 additions and 1 deletions

View File

@@ -3381,4 +3381,34 @@ std::vector<array> depends(
shapes, dtypes, std::make_shared<Depends>(to_stream(s)), all_inputs);
}
array atleast_1d(const array& a, StreamOrDevice s /* = {} */) {
if (a.ndim() == 0) {
return reshape(a, {1}, s);
}
return a;
}
array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
switch (a.ndim()) {
case 0:
return reshape(a, {1, 1}, s);
case 1:
return reshape(a, {1, static_cast<int>(a.size())}, s);
default:
return a;
}
}
array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
switch (a.ndim()) {
case 0:
return reshape(a, {1, 1, 1}, s);
case 1:
return reshape(a, {1, static_cast<int>(a.size()), 1}, s);
case 2:
return reshape(a, {a.shape(0), a.shape(1), 1}, s);
default:
return a;
}
}
} // namespace mlx::core

View File

@@ -1121,4 +1121,9 @@ std::vector<array> depends(
const std::vector<array>& inputs,
const std::vector<array>& dependencies);
/** convert an array to an atleast ndim array */
array atleast_1d(const array& a, StreamOrDevice s = {});
array atleast_2d(const array& a, StreamOrDevice s = {});
array atleast_3d(const array& a, StreamOrDevice s = {});
} // namespace mlx::core