commit 4fda5fbdf9 (parent 580776559b)
add python testing for cuda with ability to skip list of tests (#2295)
@@ -234,6 +234,7 @@ jobs:
       command: |
         source env/bin/activate
         LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
+        LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v

   build_release:
     parameters:
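The CPU job keeps the plain unittest discover invocation, while the added GPU line routes through the new tests package entry point (python -m tests), so the MLXTestRunner introduced below can drop the cases listed in cuda_skip before running.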
python/tests/__main__.py (new file, 5 lines)

from . import mlx_tests

__unittest = True

mlx_tests.MLXTestRunner(module=None)
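This makes the test suite runnable as a package, e.g. LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v, as in the CI job above. Passing module=None tells unittest.TestProgram to build its test list from the command-line arguments rather than from the calling module, and __unittest = True marks the module as unittest-internal so its frames are hidden from failure tracebacks.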
python/tests/cuda_skip.py (new file, 143 lines)
cuda_skip = {
    "TestArray.test_api",
    "TestArray.test_setitem",
    "TestAutograd.test_cumprod_grad",
    "TestAutograd.test_slice_grads",
    "TestAutograd.test_split_against_slice",
    "TestAutograd.test_stop_gradient",
    "TestAutograd.test_topk_grad",
    "TestAutograd.test_update_state",
    "TestAutograd.test_vjp",
    "TestBF16.test_arg_reduction_ops",
    "TestBF16.test_binary_ops",
    "TestBF16.test_reduction_ops",
    "TestBlas.test_block_masked_matmul",
    "TestBlas.test_complex_gemm",
    "TestBlas.test_gather_matmul",
    "TestBlas.test_gather_matmul_grad",
    "TestBlas.test_matmul_batched",
    "TestBlas.test_matrix_vector_attn",
    "TestCompile.test_compile_dynamic_dims",
    "TestCompile.test_compile_inf",
    "TestCompile.test_inf_constant",
    "TestConv.test_1d_conv_with_2d",
    "TestConv.test_asymmetric_padding",
    "TestConv.test_basic_grad_shapes",
    "TestConv.test_conv2d_unaligned_channels",
    "TestConv.test_conv_1d_groups_flipped",
    "TestConv.test_conv_general_flip_grad",
    "TestConv.test_conv_groups_grad",
    "TestConv.test_numpy_conv",
    "TestConv.test_repeated_conv",
    "TestConv.test_torch_conv_1D",
    "TestConv.test_torch_conv_1D_grad",
    "TestConv.test_torch_conv_2D",
    "TestConv.test_torch_conv_2D_grad",
    "TestConv.test_torch_conv_3D",
    "TestConv.test_torch_conv_3D_grad",
    "TestConv.test_torch_conv_depthwise",
    "TestConv.test_torch_conv_general",
    "TestConvTranspose.test_torch_conv_tranpose_1d_output_padding",
    "TestConvTranspose.test_torch_conv_transpose_1D",
    "TestConvTranspose.test_torch_conv_transpose_1D_grad",
    "TestConvTranspose.test_torch_conv_transpose_2D",
    "TestConvTranspose.test_torch_conv_transpose_2D_grad",
    "TestConvTranspose.test_torch_conv_transpose_2d_output_padding",
    "TestConvTranspose.test_torch_conv_transpose_3D",
    "TestConvTranspose.test_torch_conv_transpose_3D_grad",
    "TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
    "TestEinsum.test_attention",
    "TestEinsum.test_ellipses",
    "TestEinsum.test_opt_einsum_test_cases",
    "TestEval.test_multi_output_eval_during_transform",
    "TestExportImport.test_export_conv",
    "TestFast.test_rope_grad",
    "TestFFT.test_fft",
    "TestFFT.test_fft_big_powers_of_two",
    "TestFFT.test_fft_contiguity",
    "TestFFT.test_fft_exhaustive",
    "TestFFT.test_fft_grads",
    "TestFFT.test_fft_into_ifft",
    "TestFFT.test_fft_large_numbers",
    "TestFFT.test_fft_shared_mem",
    "TestFFT.test_fftn",
    "TestInit.test_orthogonal",
    "TestLinalg.test_cholesky",
    "TestLinalg.test_cholesky_inv",
    "TestLinalg.test_eig",
    "TestLinalg.test_eigh",
    "TestLinalg.test_inverse",
    "TestLinalg.test_lu",
    "TestLinalg.test_lu_factor",
    "TestLinalg.test_pseudo_inverse",
    "TestLinalg.test_qr_factorization",
    "TestLinalg.test_svd_decomposition",
    "TestLinalg.test_tri_inverse",
    "TestLoad.test_load_f8_e4m3",
    "TestLosses.test_binary_cross_entropy",
    "TestMemory.test_memory_info",
    "TestLayers.test_conv1d",
    "TestLayers.test_conv2d",
    "TestLayers.test_elu",
    "TestLayers.test_group_norm",
    "TestLayers.test_hard_shrink",
    "TestLayers.test_pooling",
    "TestLayers.test_quantized_embedding",
    "TestLayers.test_sin_pe",
    "TestLayers.test_softshrink",
    "TestLayers.test_upsample",
    "TestOps.test_argpartition",
    "TestOps.test_array_equal",
    "TestOps.test_as_strided",
    "TestOps.test_atleast_1d",
    "TestOps.test_atleast_2d",
    "TestOps.test_atleast_3d",
    "TestOps.test_binary_ops",
    "TestOps.test_bitwise_grad",
    "TestOps.test_complex_ops",
    "TestOps.test_divmod",
    "TestOps.test_dynamic_slicing",
    "TestOps.test_hadamard",
    "TestOps.test_hadamard_grad_vmap",
    "TestOps.test_irregular_binary_ops",
    "TestOps.test_isfinite",
    "TestOps.test_kron",
    "TestOps.test_log",
    "TestOps.test_log10",
    "TestOps.test_log1p",
    "TestOps.test_log2",
    "TestOps.test_logaddexp",
    "TestOps.test_logcumsumexp",
    "TestOps.test_partition",
    "TestOps.test_scans",
    "TestOps.test_slice_update_reversed",
    "TestOps.test_softmax",
    "TestOps.test_sort",
    "TestOps.test_tensordot",
    "TestOps.test_tile",
    "TestOps.test_view",
    "TestQuantized.test_gather_matmul_grad",
    "TestQuantized.test_gather_qmm",
    "TestQuantized.test_gather_qmm_sorted",
    "TestQuantized.test_non_multiples",
    "TestQuantized.test_qmm",
    "TestQuantized.test_qmm_jvp",
    "TestQuantized.test_qmm_shapes",
    "TestQuantized.test_qmm_vjp",
    "TestQuantized.test_qmv",
    "TestQuantized.test_quantize_dequantize",
    "TestQuantized.test_qvm",
    "TestQuantized.test_qvm_splitk",
    "TestQuantized.test_small_matrix",
    "TestQuantized.test_throw",
    "TestQuantized.test_vjp_scales_biases",
    "TestReduce.test_axis_permutation_sums",
    "TestReduce.test_dtypes",
    "TestReduce.test_expand_sums",
    "TestReduce.test_many_reduction_axes",
    "TestUpsample.test_torch_upsample",
    "TestVmap.test_unary",
    "TestVmap.test_vmap_conv",
    "TestVmap.test_vmap_inverse",
    "TestVmap.test_vmap_svd",
}
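Each entry is keyed as TestClass.test_method, the last two components of a unittest test id; MLXTestRunner (below) consults this set only when the selected device is the GPU and Metal is unavailable, i.e. the CUDA backend. A small, hypothetical inspection script (assuming it is run from python/tests, where cuda_skip.py lives):

    from collections import Counter

    from cuda_skip import cuda_skip

    # Tally the skipped cases per test class to see where the CUDA
    # coverage gaps are concentrated.
    by_class = Counter(name.split(".")[0] for name in cuda_skip)
    for cls, n in by_class.most_common():
        print(f"{cls}: {n} skipped")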
@@ -9,6 +9,42 @@ import mlx.core as mx
 import numpy as np


+class MLXTestRunner(unittest.TestProgram):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def createTests(self, *args, **kwargs):
+        super().createTests(*args, **kwargs)
+
+        # Assume CUDA backend in this case
+        device = os.getenv("DEVICE", None)
+        if device is not None:
+            device = getattr(mx, device)
+        else:
+            device = mx.default_device()
+
+        if not (device == mx.gpu and not mx.metal.is_available()):
+            return
+
+        from cuda_skip import cuda_skip
+
+        filtered_suite = unittest.TestSuite()
+
+        def filter_and_add(t):
+            if isinstance(t, unittest.TestSuite):
+                for sub_t in t:
+                    filter_and_add(sub_t)
+            else:
+                t_id = ".".join(t.id().split(".")[-2:])
+                if t_id in cuda_skip:
+                    print(f"Skipping {t_id}")
+                else:
+                    filtered_suite.addTest(t)
+
+        filter_and_add(self.test)
+        self.test = filtered_suite
+
+
 class MLXTestCase(unittest.TestCase):
     @property
     def is_apple_silicon(self):
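A minimal standalone sketch of the same filtering idea, with hypothetical test names, showing how t.id() (e.g. "__main__.TestMath.test_add") reduces to the TestClass.test_method keys used in cuda_skip:

    import unittest

    skip = {"TestMath.test_add"}  # hypothetical skip set

    class TestMath(unittest.TestCase):
        def test_add(self):
            self.assertEqual(1 + 1, 2)

        def test_mul(self):
            self.assertEqual(2 * 3, 6)

    def filter_suite(suite, skip):
        kept = unittest.TestSuite()

        def walk(t):
            # TestSuites nest arbitrarily, so recurse down to the test cases.
            if isinstance(t, unittest.TestSuite):
                for sub in t:
                    walk(sub)
            else:
                # Keep the last two components of the id to match the
                # skip-set key format.
                if ".".join(t.id().split(".")[-2:]) not in skip:
                    kept.addTest(t)

        walk(suite)
        return kept

    if __name__ == "__main__":
        suite = unittest.TestLoader().loadTestsFromTestCase(TestMath)
        # Only test_mul runs; test_add is pruned by the skip set.
        unittest.TextTestRunner(verbosity=2).run(filter_suite(suite, skip))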
@@ -130,4 +130,4 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -2033,4 +2033,4 @@ class TestArray(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -799,4 +799,4 @@ class TestAutograd(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -193,4 +193,4 @@ class TestBF16(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -1236,4 +1236,4 @@ class TestBlas(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -981,4 +981,4 @@ class TestCompile(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -38,4 +38,4 @@ class TestConstants(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -1188,4 +1188,4 @@ class TestConv(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -807,4 +807,4 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()
@@ -38,7 +38,7 @@ class TestDevice(mlx_tests.MLXTestCase):
         # Restore device
         mx.set_default_device(device)

-    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
     def test_device_context(self):
         default = mx.default_device()
         diff = mx.cpu if default == mx.gpu else mx.gpu
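Here, and again in TestEval below, a GPU-only test that previously keyed its skip condition on mx.metal.is_available() now uses mx.is_available(mx.gpu), so it still runs when the GPU backend is CUDA rather than Metal.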
@@ -114,4 +114,4 @@ class TestStream(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -294,4 +294,4 @@ class TestDouble(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -360,4 +360,4 @@ class TestEinsum(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -172,7 +172,7 @@ class TestEval(mlx_tests.MLXTestCase):
         post = mx.get_peak_memory()
         self.assertEqual(pre, post)

-    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    @unittest.skipIf(not mx.is_available(mx.gpu), "GPU is not available")
     def test_multistream_deadlock(self):
         s1 = mx.default_stream(mx.gpu)
         s2 = mx.new_stream(mx.gpu)

@@ -197,4 +197,4 @@ class TestEval(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -348,4 +348,4 @@ class TestExportImport(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -772,4 +772,4 @@ class TestFast(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -607,7 +607,7 @@ class TestSDPA(mlx_tests.MLXTestCase):
         out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

-    def test_sdpa_prommote_mask(self):
+    def test_sdpa_promote_mask(self):
         mask = mx.array(2.0, mx.bfloat16)
         D = 64
         Nq = 4

@@ -653,4 +653,4 @@ class TestSDPA(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main(failfast=True)
+    mlx_tests.MLXTestRunner(failfast=True)
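Besides the runner swap (which keeps failfast=True, since MLXTestRunner subclasses unittest.TestProgram and accepts the same keyword arguments as unittest.main), the TestSDPA hunks also carry a drive-by fix, renaming the misspelled test_sdpa_prommote_mask to test_sdpa_promote_mask.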
@@ -320,4 +320,4 @@ class TestFFT(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -34,4 +34,4 @@ class TestGraph(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -136,4 +136,4 @@ class TestInit(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -545,4 +545,4 @@ class TestLinalg(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -400,4 +400,4 @@ class TestLoad(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -414,4 +414,4 @@ class TestLosses(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -60,4 +60,4 @@ class TestMemory(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -1907,4 +1907,4 @@ class TestLayers(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -3127,4 +3127,4 @@ class TestBroadcast(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -527,4 +527,4 @@ class TestSchedulers(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -576,4 +576,4 @@ class TestQuantized(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -389,4 +389,4 @@ class TestRandom(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -155,4 +155,4 @@ class TestReduce(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main(failfast=True)
+    mlx_tests.MLXTestRunner(failfast=True)

@@ -48,4 +48,4 @@ class TestTreeUtils(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -97,4 +97,4 @@ class TestUpsample(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()

@@ -725,4 +725,4 @@ class TestVmap(mlx_tests.MLXTestCase):
 if __name__ == "__main__":
-    unittest.main()
+    mlx_tests.MLXTestRunner()
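With these uniform swaps, each test file should still be runnable directly, e.g. DEVICE=gpu python python/tests/test_ops.py -v; mlx_tests.MLXTestRunner() behaves like unittest.main() except that, on the CUDA backend, it first prunes the cases listed in cuda_skip.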