mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Added support for atleast_1d, atleast_2d, atleast_3d (#694)
This commit is contained in:

committed by
GitHub

parent
e1bdf6a8d9
commit
f883fcede0
30
mlx/ops.cpp
30
mlx/ops.cpp
@@ -3381,4 +3381,34 @@ std::vector<array> depends(
|
||||
shapes, dtypes, std::make_shared<Depends>(to_stream(s)), all_inputs);
|
||||
}
|
||||
|
||||
array atleast_1d(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (a.ndim() == 0) {
|
||||
return reshape(a, {1}, s);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
|
||||
switch (a.ndim()) {
|
||||
case 0:
|
||||
return reshape(a, {1, 1}, s);
|
||||
case 1:
|
||||
return reshape(a, {1, static_cast<int>(a.size())}, s);
|
||||
default:
|
||||
return a;
|
||||
}
|
||||
}
|
||||
|
||||
array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
|
||||
switch (a.ndim()) {
|
||||
case 0:
|
||||
return reshape(a, {1, 1, 1}, s);
|
||||
case 1:
|
||||
return reshape(a, {1, static_cast<int>(a.size()), 1}, s);
|
||||
case 2:
|
||||
return reshape(a, {a.shape(0), a.shape(1), 1}, s);
|
||||
default:
|
||||
return a;
|
||||
}
|
||||
}
|
||||
} // namespace mlx::core
|
||||
|
@@ -1121,4 +1121,9 @@ std::vector<array> depends(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& dependencies);
|
||||
|
||||
/** convert an array to an atleast ndim array */
|
||||
array atleast_1d(const array& a, StreamOrDevice s = {});
|
||||
array atleast_2d(const array& a, StreamOrDevice s = {});
|
||||
array atleast_3d(const array& a, StreamOrDevice s = {});
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Reference in New Issue
Block a user