[CUDA] Fix back-end bugs and enable corresponding tests (#2296)

* Fix some cuda back-end bugs and enable corresponding tests

* more fixes

* enable more tests

* format
This commit is contained in:
Awni Hannun
2025-06-16 08:45:40 -07:00
committed by GitHub
parent 4fda5fbdf9
commit c552ff2451
16 changed files with 115 additions and 98 deletions

View File

@@ -1,6 +1,5 @@
cuda_skip = {
"TestArray.test_api",
"TestArray.test_setitem",
"TestAutograd.test_cumprod_grad",
"TestAutograd.test_slice_grads",
"TestAutograd.test_split_against_slice",
@@ -51,7 +50,6 @@ cuda_skip = {
"TestEinsum.test_opt_einsum_test_cases",
"TestEval.test_multi_output_eval_during_transform",
"TestExportImport.test_export_conv",
"TestFast.test_rope_grad",
"TestFFT.test_fft",
"TestFFT.test_fft_big_powers_of_two",
"TestFFT.test_fft_contiguity",
@@ -89,9 +87,6 @@ cuda_skip = {
"TestOps.test_argpartition",
"TestOps.test_array_equal",
"TestOps.test_as_strided",
"TestOps.test_atleast_1d",
"TestOps.test_atleast_2d",
"TestOps.test_atleast_3d",
"TestOps.test_binary_ops",
"TestOps.test_bitwise_grad",
"TestOps.test_complex_ops",
@@ -100,22 +95,16 @@ cuda_skip = {
"TestOps.test_hadamard",
"TestOps.test_hadamard_grad_vmap",
"TestOps.test_irregular_binary_ops",
"TestOps.test_isfinite",
"TestOps.test_kron",
"TestOps.test_log",
"TestOps.test_log10",
"TestOps.test_log1p",
"TestOps.test_log2",
"TestOps.test_logaddexp",
"TestOps.test_logcumsumexp",
"TestOps.test_partition",
"TestOps.test_scans",
"TestOps.test_slice_update_reversed",
"TestOps.test_softmax",
"TestOps.test_sort",
"TestOps.test_tensordot",
"TestOps.test_tile",
"TestOps.test_view",
"TestQuantized.test_gather_matmul_grad",
"TestQuantized.test_gather_qmm",
"TestQuantized.test_gather_qmm_sorted",
@@ -136,7 +125,6 @@ cuda_skip = {
"TestReduce.test_expand_sums",
"TestReduce.test_many_reduction_axes",
"TestUpsample.test_torch_upsample",
"TestVmap.test_unary",
"TestVmap.test_vmap_conv",
"TestVmap.test_vmap_inverse",
"TestVmap.test_vmap_svd",

View File

@@ -1187,7 +1187,7 @@ class TestArray(mlx_tests.MLXTestCase):
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
check_slices(
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 0, 1])
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 2, 1])
)
# Multiple slices

View File

@@ -2586,17 +2586,6 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqualArray(result, mx.array(expected))
def test_atleast_1d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input
arrays = [
[1],
@@ -2614,23 +2603,11 @@ class TestOps(mlx_tests.MLXTestCase):
for i, array in enumerate(arrays):
mx_res = mx.atleast_1d(mx.array(array))
np_res = np.atleast_1d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
def test_atleast_2d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input
arrays = [
[1],
@@ -2648,23 +2625,11 @@ class TestOps(mlx_tests.MLXTestCase):
for i, array in enumerate(arrays):
mx_res = mx.atleast_2d(mx.array(array))
np_res = np.atleast_2d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
def test_atleast_3d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y
# Test 1D input
arrays = [
[1],
@@ -2682,10 +2647,9 @@ class TestOps(mlx_tests.MLXTestCase):
for i, array in enumerate(arrays):
mx_res = mx.atleast_3d(mx.array(array))
np_res = np.atleast_3d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
def test_issubdtype(self):
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))