mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-11 06:24:35 +08:00
add python testing for cuda with ability to skip list of tests (#2295)
This commit is contained in:
@@ -9,6 +9,42 @@ import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MLXTestRunner(unittest.TestProgram):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def createTests(self, *args, **kwargs):
|
||||
super().createTests(*args, **kwargs)
|
||||
|
||||
# Asume CUDA backend in this case
|
||||
device = os.getenv("DEVICE", None)
|
||||
if device is not None:
|
||||
device = getattr(mx, device)
|
||||
else:
|
||||
device = mx.default_device()
|
||||
|
||||
if not (device == mx.gpu and not mx.metal.is_available()):
|
||||
return
|
||||
|
||||
from cuda_skip import cuda_skip
|
||||
|
||||
filtered_suite = unittest.TestSuite()
|
||||
|
||||
def filter_and_add(t):
|
||||
if isinstance(t, unittest.TestSuite):
|
||||
for sub_t in t:
|
||||
filter_and_add(sub_t)
|
||||
else:
|
||||
t_id = ".".join(t.id().split(".")[-2:])
|
||||
if t_id in cuda_skip:
|
||||
print(f"Skipping {t_id}")
|
||||
else:
|
||||
filtered_suite.addTest(t)
|
||||
|
||||
filter_and_add(self.test)
|
||||
self.test = filtered_suite
|
||||
|
||||
|
||||
class MLXTestCase(unittest.TestCase):
|
||||
@property
|
||||
def is_apple_silicon(self):
|
||||
|
Reference in New Issue
Block a user