From 4fda5fbdf94e70eb467b8f0a4900bfdf5f8ce108 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 15 Jun 2025 10:56:48 -0700 Subject: [PATCH] add python testing for cuda with ability to skip list of tests (#2295) --- .circleci/config.yml | 1 + python/tests/__main__.py | 5 + python/tests/cuda_skip.py | 143 ++++++++++++++++++++++++++ python/tests/mlx_tests.py | 36 +++++++ python/tests/ring_test_distributed.py | 2 +- python/tests/test_array.py | 2 +- python/tests/test_autograd.py | 2 +- python/tests/test_bf16.py | 2 +- python/tests/test_blas.py | 2 +- python/tests/test_compile.py | 2 +- python/tests/test_constants.py | 2 +- python/tests/test_conv.py | 2 +- python/tests/test_conv_transpose.py | 2 +- python/tests/test_device.py | 4 +- python/tests/test_double.py | 2 +- python/tests/test_einsum.py | 2 +- python/tests/test_eval.py | 4 +- python/tests/test_export_import.py | 2 +- python/tests/test_fast.py | 2 +- python/tests/test_fast_sdpa.py | 4 +- python/tests/test_fft.py | 2 +- python/tests/test_graph.py | 2 +- python/tests/test_init.py | 2 +- python/tests/test_linalg.py | 2 +- python/tests/test_load.py | 2 +- python/tests/test_losses.py | 2 +- python/tests/test_memory.py | 2 +- python/tests/test_nn.py | 2 +- python/tests/test_ops.py | 2 +- python/tests/test_optimizers.py | 2 +- python/tests/test_quantized.py | 2 +- python/tests/test_random.py | 2 +- python/tests/test_reduce.py | 2 +- python/tests/test_tree.py | 2 +- python/tests/test_upsample.py | 2 +- python/tests/test_vmap.py | 2 +- 36 files changed, 220 insertions(+), 35 deletions(-) create mode 100644 python/tests/__main__.py create mode 100644 python/tests/cuda_skip.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 808242f9b..0ea9303db 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -234,6 +234,7 @@ jobs: command: | source env/bin/activate LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v + LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v build_release: parameters: diff --git a/python/tests/__main__.py b/python/tests/__main__.py new file mode 100644 index 000000000..5230bd535 --- /dev/null +++ b/python/tests/__main__.py @@ -0,0 +1,5 @@ +from . 
import mlx_tests + +__unittest = True + +mlx_tests.MLXTestRunner(module=None) diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py new file mode 100644 index 000000000..cda396dcb --- /dev/null +++ b/python/tests/cuda_skip.py @@ -0,0 +1,143 @@ +cuda_skip = { + "TestArray.test_api", + "TestArray.test_setitem", + "TestAutograd.test_cumprod_grad", + "TestAutograd.test_slice_grads", + "TestAutograd.test_split_against_slice", + "TestAutograd.test_stop_gradient", + "TestAutograd.test_topk_grad", + "TestAutograd.test_update_state", + "TestAutograd.test_vjp", + "TestBF16.test_arg_reduction_ops", + "TestBF16.test_binary_ops", + "TestBF16.test_reduction_ops", + "TestBlas.test_block_masked_matmul", + "TestBlas.test_complex_gemm", + "TestBlas.test_gather_matmul", + "TestBlas.test_gather_matmul_grad", + "TestBlas.test_matmul_batched", + "TestBlas.test_matrix_vector_attn", + "TestCompile.test_compile_dynamic_dims", + "TestCompile.test_compile_inf", + "TestCompile.test_inf_constant", + "TestConv.test_1d_conv_with_2d", + "TestConv.test_asymmetric_padding", + "TestConv.test_basic_grad_shapes", + "TestConv.test_conv2d_unaligned_channels", + "TestConv.test_conv_1d_groups_flipped", + "TestConv.test_conv_general_flip_grad", + "TestConv.test_conv_groups_grad", + "TestConv.test_numpy_conv", + "TestConv.test_repeated_conv", + "TestConv.test_torch_conv_1D", + "TestConv.test_torch_conv_1D_grad", + "TestConv.test_torch_conv_2D", + "TestConv.test_torch_conv_2D_grad", + "TestConv.test_torch_conv_3D", + "TestConv.test_torch_conv_3D_grad", + "TestConv.test_torch_conv_depthwise", + "TestConv.test_torch_conv_general", + "TestConvTranspose.test_torch_conv_tranpose_1d_output_padding", + "TestConvTranspose.test_torch_conv_transpose_1D", + "TestConvTranspose.test_torch_conv_transpose_1D_grad", + "TestConvTranspose.test_torch_conv_transpose_2D", + "TestConvTranspose.test_torch_conv_transpose_2D_grad", + "TestConvTranspose.test_torch_conv_transpose_2d_output_padding", + "TestConvTranspose.test_torch_conv_transpose_3D", + "TestConvTranspose.test_torch_conv_transpose_3D_grad", + "TestConvTranspose.test_torch_conv_transpose_3d_output_padding", + "TestEinsum.test_attention", + "TestEinsum.test_ellipses", + "TestEinsum.test_opt_einsum_test_cases", + "TestEval.test_multi_output_eval_during_transform", + "TestExportImport.test_export_conv", + "TestFast.test_rope_grad", + "TestFFT.test_fft", + "TestFFT.test_fft_big_powers_of_two", + "TestFFT.test_fft_contiguity", + "TestFFT.test_fft_exhaustive", + "TestFFT.test_fft_grads", + "TestFFT.test_fft_into_ifft", + "TestFFT.test_fft_large_numbers", + "TestFFT.test_fft_shared_mem", + "TestFFT.test_fftn", + "TestInit.test_orthogonal", + "TestLinalg.test_cholesky", + "TestLinalg.test_cholesky_inv", + "TestLinalg.test_eig", + "TestLinalg.test_eigh", + "TestLinalg.test_inverse", + "TestLinalg.test_lu", + "TestLinalg.test_lu_factor", + "TestLinalg.test_pseudo_inverse", + "TestLinalg.test_qr_factorization", + "TestLinalg.test_svd_decomposition", + "TestLinalg.test_tri_inverse", + "TestLoad.test_load_f8_e4m3", + "TestLosses.test_binary_cross_entropy", + "TestMemory.test_memory_info", + "TestLayers.test_conv1d", + "TestLayers.test_conv2d", + "TestLayers.test_elu", + "TestLayers.test_group_norm", + "TestLayers.test_hard_shrink", + "TestLayers.test_pooling", + "TestLayers.test_quantized_embedding", + "TestLayers.test_sin_pe", + "TestLayers.test_softshrink", + "TestLayers.test_upsample", + "TestOps.test_argpartition", + "TestOps.test_array_equal", + "TestOps.test_as_strided", + 
"TestOps.test_atleast_1d", + "TestOps.test_atleast_2d", + "TestOps.test_atleast_3d", + "TestOps.test_binary_ops", + "TestOps.test_bitwise_grad", + "TestOps.test_complex_ops", + "TestOps.test_divmod", + "TestOps.test_dynamic_slicing", + "TestOps.test_hadamard", + "TestOps.test_hadamard_grad_vmap", + "TestOps.test_irregular_binary_ops", + "TestOps.test_isfinite", + "TestOps.test_kron", + "TestOps.test_log", + "TestOps.test_log10", + "TestOps.test_log1p", + "TestOps.test_log2", + "TestOps.test_logaddexp", + "TestOps.test_logcumsumexp", + "TestOps.test_partition", + "TestOps.test_scans", + "TestOps.test_slice_update_reversed", + "TestOps.test_softmax", + "TestOps.test_sort", + "TestOps.test_tensordot", + "TestOps.test_tile", + "TestOps.test_view", + "TestQuantized.test_gather_matmul_grad", + "TestQuantized.test_gather_qmm", + "TestQuantized.test_gather_qmm_sorted", + "TestQuantized.test_non_multiples", + "TestQuantized.test_qmm", + "TestQuantized.test_qmm_jvp", + "TestQuantized.test_qmm_shapes", + "TestQuantized.test_qmm_vjp", + "TestQuantized.test_qmv", + "TestQuantized.test_quantize_dequantize", + "TestQuantized.test_qvm", + "TestQuantized.test_qvm_splitk", + "TestQuantized.test_small_matrix", + "TestQuantized.test_throw", + "TestQuantized.test_vjp_scales_biases", + "TestReduce.test_axis_permutation_sums", + "TestReduce.test_dtypes", + "TestReduce.test_expand_sums", + "TestReduce.test_many_reduction_axes", + "TestUpsample.test_torch_upsample", + "TestVmap.test_unary", + "TestVmap.test_vmap_conv", + "TestVmap.test_vmap_inverse", + "TestVmap.test_vmap_svd", +} diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py index f446b5e67..65bd0e873 100644 --- a/python/tests/mlx_tests.py +++ b/python/tests/mlx_tests.py @@ -9,6 +9,42 @@ import mlx.core as mx import numpy as np +class MLXTestRunner(unittest.TestProgram): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def createTests(self, *args, **kwargs): + super().createTests(*args, **kwargs) + + # Asume CUDA backend in this case + device = os.getenv("DEVICE", None) + if device is not None: + device = getattr(mx, device) + else: + device = mx.default_device() + + if not (device == mx.gpu and not mx.metal.is_available()): + return + + from cuda_skip import cuda_skip + + filtered_suite = unittest.TestSuite() + + def filter_and_add(t): + if isinstance(t, unittest.TestSuite): + for sub_t in t: + filter_and_add(sub_t) + else: + t_id = ".".join(t.id().split(".")[-2:]) + if t_id in cuda_skip: + print(f"Skipping {t_id}") + else: + filtered_suite.addTest(t) + + filter_and_add(self.test) + self.test = filtered_suite + + class MLXTestCase(unittest.TestCase): @property def is_apple_silicon(self): diff --git a/python/tests/ring_test_distributed.py b/python/tests/ring_test_distributed.py index 77d45dbad..213f85274 100644 --- a/python/tests/ring_test_distributed.py +++ b/python/tests/ring_test_distributed.py @@ -130,4 +130,4 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_array.py b/python/tests/test_array.py index c22e0a38f..c02b524b4 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2033,4 +2033,4 @@ class TestArray(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index ec9d957ea..7973d79be 100644 --- a/python/tests/test_autograd.py +++ 
b/python/tests/test_autograd.py @@ -799,4 +799,4 @@ class TestAutograd(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_bf16.py b/python/tests/test_bf16.py index 0b4b49919..2e4e2e0c3 100644 --- a/python/tests/test_bf16.py +++ b/python/tests/test_bf16.py @@ -193,4 +193,4 @@ class TestBF16(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 2762df8f8..eb45df124 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1236,4 +1236,4 @@ class TestBlas(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index f5ce496cd..656553f9d 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -981,4 +981,4 @@ class TestCompile(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_constants.py b/python/tests/test_constants.py index 104e7522d..cfd971fbe 100644 --- a/python/tests/test_constants.py +++ b/python/tests/test_constants.py @@ -38,4 +38,4 @@ class TestConstants(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index c68315a5d..9be22e01b 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -1188,4 +1188,4 @@ class TestConv(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_conv_transpose.py b/python/tests/test_conv_transpose.py index 2085e09d7..7289955ed 100644 --- a/python/tests/test_conv_transpose.py +++ b/python/tests/test_conv_transpose.py @@ -807,4 +807,4 @@ class TestConvTranspose(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_device.py b/python/tests/test_device.py index 6793c98d1..d51028def 100644 --- a/python/tests/test_device.py +++ b/python/tests/test_device.py @@ -38,7 +38,7 @@ class TestDevice(mlx_tests.MLXTestCase): # Restore device mx.set_default_device(device) - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available") def test_device_context(self): default = mx.default_device() diff = mx.cpu if default == mx.gpu else mx.gpu @@ -114,4 +114,4 @@ class TestStream(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_double.py b/python/tests/test_double.py index 10fce0db1..fccf3628f 100644 --- a/python/tests/test_double.py +++ b/python/tests/test_double.py @@ -294,4 +294,4 @@ class TestDouble(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_einsum.py b/python/tests/test_einsum.py index 19ea8178e..a73ea3818 100644 --- a/python/tests/test_einsum.py +++ b/python/tests/test_einsum.py @@ -360,4 +360,4 @@ class TestEinsum(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index fcd424343..5d6daaec2 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -172,7 +172,7 @@ class TestEval(mlx_tests.MLXTestCase): post = 
mx.get_peak_memory() self.assertEqual(pre, post) - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available") def test_multistream_deadlock(self): s1 = mx.default_stream(mx.gpu) s2 = mx.new_stream(mx.gpu) @@ -197,4 +197,4 @@ class TestEval(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 0fd8bfd87..099be0cc0 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -348,4 +348,4 @@ class TestExportImport(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 59c2fc3ef..13c65de99 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -772,4 +772,4 @@ class TestFast(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 8f55d41e3..a929e91cf 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -607,7 +607,7 @@ class TestSDPA(mlx_tests.MLXTestCase): out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask) self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) - def test_sdpa_prommote_mask(self): + def test_sdpa_promote_mask(self): mask = mx.array(2.0, mx.bfloat16) D = 64 Nq = 4 @@ -653,4 +653,4 @@ class TestSDPA(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main(failfast=True) + mlx_tests.MLXTestRunner(failfast=True) diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index df9d25edc..07ab62672 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -320,4 +320,4 @@ class TestFFT(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_graph.py b/python/tests/test_graph.py index 4b8f6d86a..7c6a11412 100644 --- a/python/tests/test_graph.py +++ b/python/tests/test_graph.py @@ -34,4 +34,4 @@ class TestGraph(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_init.py b/python/tests/test_init.py index 4b209736f..046a6e836 100644 --- a/python/tests/test_init.py +++ b/python/tests/test_init.py @@ -136,4 +136,4 @@ class TestInit(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 764d11f6e..81a43ed7f 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -545,4 +545,4 @@ class TestLinalg(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_load.py b/python/tests/test_load.py index 341564dae..35f7016c5 100644 --- a/python/tests/test_load.py +++ b/python/tests/test_load.py @@ -400,4 +400,4 @@ class TestLoad(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py index 102ec857d..cbc657655 100644 --- a/python/tests/test_losses.py +++ b/python/tests/test_losses.py @@ -414,4 +414,4 @@ class TestLosses(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git 
a/python/tests/test_memory.py b/python/tests/test_memory.py index 7343bdc91..08da7ccc6 100644 --- a/python/tests/test_memory.py +++ b/python/tests/test_memory.py @@ -60,4 +60,4 @@ class TestMemory(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 13e31ad96..10bbe821e 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -1907,4 +1907,4 @@ class TestLayers(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 7c4f3f8e3..02ada39b4 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3127,4 +3127,4 @@ class TestBroadcast(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py index 4943fe662..e07fc8456 100644 --- a/python/tests/test_optimizers.py +++ b/python/tests/test_optimizers.py @@ -527,4 +527,4 @@ class TestSchedulers(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 3c4f03e4d..f402bd444 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -576,4 +576,4 @@ class TestQuantized(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_random.py b/python/tests/test_random.py index 2fc768651..551c32993 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -389,4 +389,4 @@ class TestRandom(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py index 9012216ba..2b899c099 100644 --- a/python/tests/test_reduce.py +++ b/python/tests/test_reduce.py @@ -155,4 +155,4 @@ class TestReduce(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main(failfast=True) + mlx_tests.MLXTestRunner(failfast=True) diff --git a/python/tests/test_tree.py b/python/tests/test_tree.py index 63018fdae..bacf6e71d 100644 --- a/python/tests/test_tree.py +++ b/python/tests/test_tree.py @@ -48,4 +48,4 @@ class TestTreeUtils(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_upsample.py b/python/tests/test_upsample.py index 86f41b6e8..631853cce 100644 --- a/python/tests/test_upsample.py +++ b/python/tests/test_upsample.py @@ -97,4 +97,4 @@ class TestUpsample(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 52f1a49ad..a88e59585 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -725,4 +725,4 @@ class TestVmap(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() + mlx_tests.MLXTestRunner()
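
Note on the skip mechanism: the heart of this patch is MLXTestRunner.createTests in python/tests/mlx_tests.py, which walks the discovered suite and drops every test whose unittest id, trimmed to its last two components (e.g. "tests.test_ops.TestOps.test_log" becomes "TestOps.test_log"), appears in the cuda_skip set from python/tests/cuda_skip.py. The standalone sketch below illustrates that matching rule outside the patch; the DemoTest class, the filter_suite helper, and the single stand-in skip entry are hypothetical placeholders for the real discovery and skip list.

    # Standalone sketch of the skip-matching rule used by MLXTestRunner.createTests.
    # DemoTest, filter_suite, and the stand-in skip set are hypothetical; the real
    # entries live in python/tests/cuda_skip.py.
    import unittest

    cuda_skip = {"DemoTest.test_skipped"}


    class DemoTest(unittest.TestCase):
        def test_skipped(self):
            self.fail("should have been filtered out")

        def test_kept(self):
            self.assertTrue(True)


    def filter_suite(suite):
        filtered = unittest.TestSuite()

        def filter_and_add(t):
            if isinstance(t, unittest.TestSuite):
                for sub_t in t:
                    filter_and_add(sub_t)
            else:
                # Trim the id to its last two parts, the format used by cuda_skip.
                t_id = ".".join(t.id().split(".")[-2:])
                if t_id in cuda_skip:
                    print(f"Skipping {t_id}")
                else:
                    filtered.addTest(t)

        filter_and_add(suite)
        return filtered


    if __name__ == "__main__":
        suite = unittest.defaultTestLoader.loadTestsFromTestCase(DemoTest)
        unittest.TextTestRunner(verbosity=2).run(filter_suite(suite))

Running the sketch reports "Skipping DemoTest.test_skipped" and executes only test_kept, mirroring how the runner prunes the CUDA skip list before handing the suite to unittest.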
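
Note on when the skip list applies: MLXTestRunner reads the DEVICE environment variable (the new CircleCI step sets DEVICE=gpu), maps it to mx.cpu or mx.gpu with getattr, and falls back to mx.default_device() when it is unset; the cuda_skip filtering is applied only when the resulting device is the GPU and Metal is unavailable, which is how the runner identifies the CUDA backend. The snippet below is an illustrative sketch of that gate only; the variable names and the print are mine, not from the patch, and it assumes an installed mlx and runs no tests.

    # Sketch of the device gate in MLXTestRunner.createTests: cuda_skip is consulted
    # only when the target device is the GPU and Metal is unavailable (CUDA backend).
    import os

    import mlx.core as mx

    device = os.getenv("DEVICE", None)
    device = getattr(mx, device) if device is not None else mx.default_device()

    apply_cuda_skips = device == mx.gpu and not mx.metal.is_available()
    print(f"target device: {device}, applying cuda_skip: {apply_cuda_skips}")

Because MLXTestRunner subclasses unittest.TestProgram, the per-file mlx_tests.MLXTestRunner() calls added throughout python/tests are drop-in replacements for unittest.main(), and the new python/tests/__main__.py lets the CI step drive the same runner via "python -m tests discover python/tests -v".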