mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:18:12 +08:00
added atleast *args input support (#710)
* added atleast list(array) input support * function overloading implemented * Refactoring * fixed formatting * removed pos_only
This commit is contained in:

committed by
GitHub

parent
3b661b7394
commit
08226ab491
@@ -1932,12 +1932,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
[[[[1]], [[2]], [[3]]]],
|
||||
]
|
||||
|
||||
for array in arrays:
|
||||
mx_arrays = [mx.atleast_1d(mx.array(x)) for x in arrays]
|
||||
atleast_arrays = mx.atleast_1d(*mx_arrays)
|
||||
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_1d(mx.array(array))
|
||||
np_res = np.atleast_1d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
|
||||
def test_atleast_2d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
@@ -1962,12 +1966,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
[[[[1]], [[2]], [[3]]]],
|
||||
]
|
||||
|
||||
for array in arrays:
|
||||
mx_arrays = [mx.atleast_2d(mx.array(x)) for x in arrays]
|
||||
atleast_arrays = mx.atleast_2d(*mx_arrays)
|
||||
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_2d(mx.array(array))
|
||||
np_res = np.atleast_2d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
|
||||
def test_atleast_3d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
@@ -1992,12 +2000,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
[[[[1]], [[2]], [[3]]]],
|
||||
]
|
||||
|
||||
for array in arrays:
|
||||
mx_arrays = [mx.atleast_3d(mx.array(x)) for x in arrays]
|
||||
atleast_arrays = mx.atleast_3d(*mx_arrays)
|
||||
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_3d(mx.array(array))
|
||||
np_res = np.atleast_3d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user