mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Added support for atleast_1d, atleast_2d, atleast_3d (#694)
This commit is contained in:

committed by
GitHub

parent
e1bdf6a8d9
commit
f883fcede0
@@ -1883,6 +1883,96 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
expected = mx.array(np.diag(x, k=-1))
|
||||
self.assertTrue(mx.array_equal(result, expected))
|
||||
|
||||
def test_atleast_1d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
if isinstance(x, list) and isinstance(y, list):
|
||||
if len(x) != len(y):
|
||||
return False
|
||||
for i in range(len(x)):
|
||||
if not compare_nested_lists(x[i], y[i]):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return x == y
|
||||
|
||||
# Test 1D input
|
||||
arrays = [
|
||||
[1],
|
||||
[1, 2, 3],
|
||||
[1, 2, 3, 4],
|
||||
[[1], [2], [3]],
|
||||
[[1, 2], [3, 4]],
|
||||
[[1, 2, 3], [4, 5, 6]],
|
||||
[[[[1]], [[2]], [[3]]]],
|
||||
]
|
||||
|
||||
for array in arrays:
|
||||
mx_res = mx.atleast_1d(mx.array(array))
|
||||
np_res = np.atleast_1d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
|
||||
def test_atleast_2d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
if isinstance(x, list) and isinstance(y, list):
|
||||
if len(x) != len(y):
|
||||
return False
|
||||
for i in range(len(x)):
|
||||
if not compare_nested_lists(x[i], y[i]):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return x == y
|
||||
|
||||
# Test 1D input
|
||||
arrays = [
|
||||
[1],
|
||||
[1, 2, 3],
|
||||
[1, 2, 3, 4],
|
||||
[[1], [2], [3]],
|
||||
[[1, 2], [3, 4]],
|
||||
[[1, 2, 3], [4, 5, 6]],
|
||||
[[[[1]], [[2]], [[3]]]],
|
||||
]
|
||||
|
||||
for array in arrays:
|
||||
mx_res = mx.atleast_2d(mx.array(array))
|
||||
np_res = np.atleast_2d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
|
||||
def test_atleast_3d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
if isinstance(x, list) and isinstance(y, list):
|
||||
if len(x) != len(y):
|
||||
return False
|
||||
for i in range(len(x)):
|
||||
if not compare_nested_lists(x[i], y[i]):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return x == y
|
||||
|
||||
# Test 1D input
|
||||
arrays = [
|
||||
[1],
|
||||
[1, 2, 3],
|
||||
[1, 2, 3, 4],
|
||||
[[1], [2], [3]],
|
||||
[[1, 2], [3, 4]],
|
||||
[[1, 2, 3], [4, 5, 6]],
|
||||
[[[[1]], [[2]], [[3]]]],
|
||||
]
|
||||
|
||||
for array in arrays:
|
||||
mx_res = mx.atleast_3d(mx.array(array))
|
||||
np_res = np.atleast_3d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user