mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 11:28:12 +08:00
[CUDA] Fix back-end bugs and enable corresponding tests (#2296)
* Fix some cuda back-end bugs and enable corresponding tests * more fixes * enable more tests * format
This commit is contained in:
@@ -2586,17 +2586,6 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqualArray(result, mx.array(expected))
|
||||
|
||||
def test_atleast_1d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
if isinstance(x, list) and isinstance(y, list):
|
||||
if len(x) != len(y):
|
||||
return False
|
||||
for i in range(len(x)):
|
||||
if not compare_nested_lists(x[i], y[i]):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return x == y
|
||||
|
||||
# Test 1D input
|
||||
arrays = [
|
||||
[1],
|
||||
@@ -2614,23 +2603,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_1d(mx.array(array))
|
||||
np_res = np.atleast_1d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||
|
||||
def test_atleast_2d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
if isinstance(x, list) and isinstance(y, list):
|
||||
if len(x) != len(y):
|
||||
return False
|
||||
for i in range(len(x)):
|
||||
if not compare_nested_lists(x[i], y[i]):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return x == y
|
||||
|
||||
# Test 1D input
|
||||
arrays = [
|
||||
[1],
|
||||
@@ -2648,23 +2625,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_2d(mx.array(array))
|
||||
np_res = np.atleast_2d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||
|
||||
def test_atleast_3d(self):
|
||||
def compare_nested_lists(x, y):
|
||||
if isinstance(x, list) and isinstance(y, list):
|
||||
if len(x) != len(y):
|
||||
return False
|
||||
for i in range(len(x)):
|
||||
if not compare_nested_lists(x[i], y[i]):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return x == y
|
||||
|
||||
# Test 1D input
|
||||
arrays = [
|
||||
[1],
|
||||
@@ -2682,10 +2647,9 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for i, array in enumerate(arrays):
|
||||
mx_res = mx.atleast_3d(mx.array(array))
|
||||
np_res = np.atleast_3d(np.array(array))
|
||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
||||
self.assertEqual(mx_res.shape, np_res.shape)
|
||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
||||
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||
|
||||
def test_issubdtype(self):
|
||||
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))
|
||||
|
||||
Reference in New Issue
Block a user