mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
added atleast *args input support (#710)
* added atleast list(array) input support * function overloading implemented * Refactoring * fixed formatting * removed pos_only
This commit is contained in:
parent
3b661b7394
commit
08226ab491
34
mlx/ops.cpp
34
mlx/ops.cpp
@ -3414,6 +3414,17 @@ array atleast_1d(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return a;
|
||||
}
|
||||
|
||||
std::vector<array> atleast_1d(
|
||||
const std::vector<array>& arrays,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
std::vector<array> out;
|
||||
out.reserve(arrays.size());
|
||||
for (const auto& a : arrays) {
|
||||
out.push_back(atleast_1d(a, s));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
|
||||
switch (a.ndim()) {
|
||||
case 0:
|
||||
@ -3425,6 +3436,17 @@ array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> atleast_2d(
|
||||
const std::vector<array>& arrays,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
std::vector<array> out;
|
||||
out.reserve(arrays.size());
|
||||
for (const auto& a : arrays) {
|
||||
out.push_back(atleast_2d(a, s));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
|
||||
switch (a.ndim()) {
|
||||
case 0:
|
||||
@ -3437,4 +3459,16 @@ array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return a;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> atleast_3d(
|
||||
const std::vector<array>& arrays,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
std::vector<array> out;
|
||||
out.reserve(arrays.size());
|
||||
for (const auto& a : arrays) {
|
||||
out.push_back(atleast_3d(a, s));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -1123,7 +1123,16 @@ std::vector<array> depends(
|
||||
|
||||
/** convert an array to an atleast ndim array */
|
||||
array atleast_1d(const array& a, StreamOrDevice s = {});
|
||||
std::vector<array> atleast_1d(
|
||||
const std::vector<array>& a,
|
||||
StreamOrDevice s = {});
|
||||
array atleast_2d(const array& a, StreamOrDevice s = {});
|
||||
std::vector<array> atleast_2d(
|
||||
const std::vector<array>& a,
|
||||
StreamOrDevice s = {});
|
||||
array atleast_3d(const array& a, StreamOrDevice s = {});
|
||||
std::vector<array> atleast_3d(
|
||||
const std::vector<array>& a,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -3638,62 +3638,69 @@ void init_ops(py::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"atleast_1d",
|
||||
&atleast_1d,
|
||||
"a"_a,
|
||||
py::pos_only(),
|
||||
[](const py::args& arys, StreamOrDevice s) -> py::object {
|
||||
if (arys.size() == 1) {
|
||||
return py::cast(atleast_1d(arys[0].cast<array>(), s));
|
||||
}
|
||||
return py::cast(atleast_1d(arys.cast<std::vector<array>>(), s));
|
||||
},
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
atleast_1d(a: array, stream: Union[None, Stream, Device] = None) -> array
|
||||
atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]
|
||||
|
||||
Convert array to have at least one dimension.
|
||||
Convert all arrays to have at least one dimension.
|
||||
|
||||
args:
|
||||
a (array): Input array
|
||||
Args:
|
||||
*arys: Input arrays.
|
||||
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
|
||||
|
||||
Returns:
|
||||
array: An array with at least one dimension.
|
||||
|
||||
array or list(array): An array or list of arrays with at least one dimension.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"atleast_2d",
|
||||
&atleast_2d,
|
||||
"a"_a,
|
||||
py::pos_only(),
|
||||
[](const py::args& arys, StreamOrDevice s) -> py::object {
|
||||
if (arys.size() == 1) {
|
||||
return py::cast(atleast_2d(arys[0].cast<array>(), s));
|
||||
}
|
||||
return py::cast(atleast_2d(arys.cast<std::vector<array>>(), s));
|
||||
},
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
atleast_2d(a: array, stream: Union[None, Stream, Device] = None) -> array
|
||||
atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]
|
||||
|
||||
Convert array to have at least two dimensions.
|
||||
Convert all arrays to have at least two dimensions.
|
||||
|
||||
args:
|
||||
a (array): Input array
|
||||
Args:
|
||||
*arys: Input arrays.
|
||||
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
|
||||
|
||||
Returns:
|
||||
array: An array with at least two dimensions.
|
||||
|
||||
array or list(array): An array or list of arrays with at least two dimensions.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"atleast_3d",
|
||||
&atleast_3d,
|
||||
"a"_a,
|
||||
py::pos_only(),
|
||||
[](const py::args& arys, StreamOrDevice s) -> py::object {
|
||||
if (arys.size() == 1) {
|
||||
return py::cast(atleast_3d(arys[0].cast<array>(), s));
|
||||
}
|
||||
return py::cast(atleast_3d(arys.cast<std::vector<array>>(), s));
|
||||
},
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array
|
||||
atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]
|
||||
|
||||
Convert array to have at least three dimensions.
|
||||
Convert all arrays to have at least three dimensions.
|
||||
|
||||
args:
|
||||
a (array): Input array
|
||||
Args:
|
||||
*arys: Input arrays.
|
||||
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
|
||||
|
||||
Returns:
|
||||
array: An array with at least three dimensions.
|
||||
|
||||
array or list(array): An array or list of arrays with at least three dimensions.
|
||||
)pbdoc");
|
||||
}
|
||||
|
@ -1932,12 +1932,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
[[[[1]], [[2]], [[3]]]],
|
||||
]
|
||||
|
||||
for array in arrays:
|
||||
mx_arrays = [mx.atleast_1d(mx.array(x)) for x in arrays]
|
||||
atleast_arrays = mx.atleast_1d(*mx_arrays)
|
||||
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_1d(mx.array(array))
|
||||
np_res = np.atleast_1d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
|
||||
def test_atleast_2d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
@ -1962,12 +1966,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
[[[[1]], [[2]], [[3]]]],
|
||||
]
|
||||
|
||||
for array in arrays:
|
||||
mx_arrays = [mx.atleast_2d(mx.array(x)) for x in arrays]
|
||||
atleast_arrays = mx.atleast_2d(*mx_arrays)
|
||||
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_2d(mx.array(array))
|
||||
np_res = np.atleast_2d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
|
||||
def test_atleast_3d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
@ -1992,12 +2000,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
[[[[1]], [[2]], [[3]]]],
|
||||
]
|
||||
|
||||
for array in arrays:
|
||||
mx_arrays = [mx.atleast_3d(mx.array(x)) for x in arrays]
|
||||
atleast_arrays = mx.atleast_3d(*mx_arrays)
|
||||
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_3d(mx.array(array))
|
||||
np_res = np.atleast_3d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -2787,6 +2787,19 @@ TEST_CASE("test atleast_1d") {
|
||||
CHECK_EQ(out.shape(), std::vector<int>{3, 1});
|
||||
}
|
||||
|
||||
TEST_CASE("test atleast_1d vector") {
|
||||
auto x = std::vector<array>{
|
||||
array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
|
||||
auto out = atleast_1d(x);
|
||||
CHECK_EQ(out.size(), 3);
|
||||
CHECK_EQ(out[0].ndim(), 1);
|
||||
CHECK_EQ(out[0].shape(), std::vector<int>{1});
|
||||
CHECK_EQ(out[1].ndim(), 1);
|
||||
CHECK_EQ(out[1].shape(), std::vector<int>{3});
|
||||
CHECK_EQ(out[2].ndim(), 2);
|
||||
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1});
|
||||
}
|
||||
|
||||
TEST_CASE("test atleast_2d") {
|
||||
auto x = array(1);
|
||||
auto out = atleast_2d(x);
|
||||
@ -2804,6 +2817,19 @@ TEST_CASE("test atleast_2d") {
|
||||
CHECK_EQ(out.shape(), std::vector<int>{3, 1});
|
||||
}
|
||||
|
||||
TEST_CASE("test atleast_2d vector") {
|
||||
auto x = std::vector<array>{
|
||||
array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
|
||||
auto out = atleast_2d(x);
|
||||
CHECK_EQ(out.size(), 3);
|
||||
CHECK_EQ(out[0].ndim(), 2);
|
||||
CHECK_EQ(out[0].shape(), std::vector<int>{1, 1});
|
||||
CHECK_EQ(out[1].ndim(), 2);
|
||||
CHECK_EQ(out[1].shape(), std::vector<int>{1, 3});
|
||||
CHECK_EQ(out[2].ndim(), 2);
|
||||
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1});
|
||||
}
|
||||
|
||||
TEST_CASE("test atleast_3d") {
|
||||
auto x = array(1);
|
||||
auto out = atleast_3d(x);
|
||||
@ -2820,3 +2846,16 @@ TEST_CASE("test atleast_3d") {
|
||||
CHECK_EQ(out.ndim(), 3);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{3, 1, 1});
|
||||
}
|
||||
|
||||
TEST_CASE("test atleast_3d vector") {
|
||||
auto x = std::vector<array>{
|
||||
array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
|
||||
auto out = atleast_3d(x);
|
||||
CHECK_EQ(out.size(), 3);
|
||||
CHECK_EQ(out[0].ndim(), 3);
|
||||
CHECK_EQ(out[0].shape(), std::vector<int>{1, 1, 1});
|
||||
CHECK_EQ(out[1].ndim(), 3);
|
||||
CHECK_EQ(out[1].shape(), std::vector<int>{1, 3, 1});
|
||||
CHECK_EQ(out[2].ndim(), 3);
|
||||
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1, 1});
|
||||
}
|
Loading…
Reference in New Issue
Block a user