Added support for atleast_1d, atleast_2d, atleast_3d (#694)

This commit is contained in:
Hinrik Snær Guðmundsson
2024-02-19 12:40:52 -05:00
committed by GitHub
parent e1bdf6a8d9
commit f883fcede0
7 changed files with 241 additions and 1 deletions

View File

@@ -3636,4 +3636,64 @@ void init_ops(py::module_& m) {
Returns:
array: The extracted diagonal or the constructed diagonal matrix.
)pbdoc");
m.def(
"atleast_1d",
&atleast_1d,
"a"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_1d(a: array, stream: Union[None, Stream, Device] = None) -> array
Convert array to have at least one dimension.
args:
a (array): Input array
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
Returns:
array: An array with at least one dimension.
)pbdoc");
m.def(
"atleast_2d",
&atleast_2d,
"a"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_2d(a: array, stream: Union[None, Stream, Device] = None) -> array
Convert array to have at least two dimensions.
args:
a (array): Input array
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
Returns:
array: An array with at least two dimensions.
)pbdoc");
m.def(
"atleast_3d",
&atleast_3d,
"a"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array
Convert array to have at least three dimensions.
args:
a (array): Input array
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
Returns:
array: An array with at least three dimensions.
)pbdoc");
}

View File

@@ -1883,6 +1883,96 @@ class TestOps(mlx_tests.MLXTestCase):
expected = mx.array(np.diag(x, k=-1))
self.assertTrue(mx.array_equal(result, expected))
def test_atleast_1d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input
arrays = [
[1],
[1, 2, 3],
[1, 2, 3, 4],
[[1], [2], [3]],
[[1, 2], [3, 4]],
[[1, 2, 3], [4, 5, 6]],
[[[[1]], [[2]], [[3]]]],
]
for array in arrays:
mx_res = mx.atleast_1d(mx.array(array))
np_res = np.atleast_1d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
def test_atleast_2d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input
arrays = [
[1],
[1, 2, 3],
[1, 2, 3, 4],
[[1], [2], [3]],
[[1, 2], [3, 4]],
[[1, 2, 3], [4, 5, 6]],
[[[[1]], [[2]], [[3]]]],
]
for array in arrays:
mx_res = mx.atleast_2d(mx.array(array))
np_res = np.atleast_2d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
def test_atleast_3d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input
arrays = [
[1],
[1, 2, 3],
[1, 2, 3, 4],
[[1], [2], [3]],
[[1, 2], [3, 4]],
[[1, 2, 3], [4, 5, 6]],
[[[[1]], [[2]], [[3]]]],
]
for array in arrays:
mx_res = mx.atleast_3d(mx.array(array))
np_res = np.atleast_3d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
if __name__ == "__main__":
unittest.main()