added atleast *args input support (#710)

* added atleast list(array) input support

* function overloading implemented

* Refactoring

* fixed formatting

* removed pos_only
This commit is contained in:
Hinrik Snær Guðmundsson
2024-02-26 14:17:59 -05:00
committed by GitHub
parent 3b661b7394
commit 08226ab491
5 changed files with 131 additions and 30 deletions

View File

@@ -3414,6 +3414,17 @@ array atleast_1d(const array& a, StreamOrDevice s /* = {} */) {
return a;
}
std::vector<array> atleast_1d(
const std::vector<array>& arrays,
StreamOrDevice s /* = {} */) {
std::vector<array> out;
out.reserve(arrays.size());
for (const auto& a : arrays) {
out.push_back(atleast_1d(a, s));
}
return out;
}
array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
switch (a.ndim()) {
case 0:
@@ -3425,6 +3436,17 @@ array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
}
}
std::vector<array> atleast_2d(
const std::vector<array>& arrays,
StreamOrDevice s /* = {} */) {
std::vector<array> out;
out.reserve(arrays.size());
for (const auto& a : arrays) {
out.push_back(atleast_2d(a, s));
}
return out;
}
array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
switch (a.ndim()) {
case 0:
@@ -3437,4 +3459,16 @@ array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
return a;
}
}
std::vector<array> atleast_3d(
const std::vector<array>& arrays,
StreamOrDevice s /* = {} */) {
std::vector<array> out;
out.reserve(arrays.size());
for (const auto& a : arrays) {
out.push_back(atleast_3d(a, s));
}
return out;
}
} // namespace mlx::core

View File

@@ -1123,7 +1123,16 @@ std::vector<array> depends(
/** convert an array to an atleast ndim array */
array atleast_1d(const array& a, StreamOrDevice s = {});
std::vector<array> atleast_1d(
const std::vector<array>& a,
StreamOrDevice s = {});
array atleast_2d(const array& a, StreamOrDevice s = {});
std::vector<array> atleast_2d(
const std::vector<array>& a,
StreamOrDevice s = {});
array atleast_3d(const array& a, StreamOrDevice s = {});
std::vector<array> atleast_3d(
const std::vector<array>& a,
StreamOrDevice s = {});
} // namespace mlx::core