mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
added atleast *args input support (#710)
* added atleast list(array) input support * function overloading implemented * Refactoring * fixed formatting * removed pos_only
This commit is contained in:

committed by
GitHub

parent
3b661b7394
commit
08226ab491
34
mlx/ops.cpp
34
mlx/ops.cpp
@@ -3414,6 +3414,17 @@ array atleast_1d(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return a;
|
||||
}
|
||||
|
||||
std::vector<array> atleast_1d(
|
||||
const std::vector<array>& arrays,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
std::vector<array> out;
|
||||
out.reserve(arrays.size());
|
||||
for (const auto& a : arrays) {
|
||||
out.push_back(atleast_1d(a, s));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
|
||||
switch (a.ndim()) {
|
||||
case 0:
|
||||
@@ -3425,6 +3436,17 @@ array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> atleast_2d(
|
||||
const std::vector<array>& arrays,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
std::vector<array> out;
|
||||
out.reserve(arrays.size());
|
||||
for (const auto& a : arrays) {
|
||||
out.push_back(atleast_2d(a, s));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
|
||||
switch (a.ndim()) {
|
||||
case 0:
|
||||
@@ -3437,4 +3459,16 @@ array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return a;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> atleast_3d(
|
||||
const std::vector<array>& arrays,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
std::vector<array> out;
|
||||
out.reserve(arrays.size());
|
||||
for (const auto& a : arrays) {
|
||||
out.push_back(atleast_3d(a, s));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1123,7 +1123,16 @@ std::vector<array> depends(
|
||||
|
||||
/** convert an array to an atleast ndim array */
|
||||
array atleast_1d(const array& a, StreamOrDevice s = {});
|
||||
std::vector<array> atleast_1d(
|
||||
const std::vector<array>& a,
|
||||
StreamOrDevice s = {});
|
||||
array atleast_2d(const array& a, StreamOrDevice s = {});
|
||||
std::vector<array> atleast_2d(
|
||||
const std::vector<array>& a,
|
||||
StreamOrDevice s = {});
|
||||
array atleast_3d(const array& a, StreamOrDevice s = {});
|
||||
std::vector<array> atleast_3d(
|
||||
const std::vector<array>& a,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Reference in New Issue
Block a user