mlx/python/tests/mlx_tests.py

110 lines
3.4 KiB
Python
Raw Normal View History

2023-12-01 03:12:53 +08:00
# Copyright © 2023 Apple Inc.
2023-11-30 02:52:08 +08:00
import os
import platform
2023-11-30 02:52:08 +08:00
import unittest
2024-01-04 11:33:19 +08:00
from typing import Any, Callable, List, Tuple, Union
2023-11-30 02:52:08 +08:00
import mlx.core as mx
import numpy as np
2023-11-30 02:52:08 +08:00
class MLXTestRunner(unittest.TestProgram):
    """unittest entry point that prunes tests listed in ``cuda_skip``.

    If the selected device is the GPU but Metal is not available, the GPU is
    taken to be a CUDA backend. Every test whose ``Class.method`` id appears in
    ``cuda_skip.cuda_skip`` is then removed from the suite (and reported on
    stdout). On any other configuration the suite is left as built.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def createTests(self, *args, **kwargs):
        super().createTests(*args, **kwargs)

        # DEVICE (e.g. "cpu" / "gpu") names an attribute of mlx.core and
        # overrides the library's default device when set.
        requested = os.getenv("DEVICE", None)
        device = mx.default_device() if requested is None else getattr(mx, requested)

        # Assume a CUDA backend when a GPU is requested without Metal support.
        on_cuda = device == mx.gpu and not mx.metal.is_available()
        if not on_cuda:
            return

        from cuda_skip import cuda_skip

        def leaves(node):
            # Depth-first walk yielding individual test cases in suite order.
            if isinstance(node, unittest.TestSuite):
                for child in node:
                    yield from leaves(child)
            else:
                yield node

        kept = unittest.TestSuite()
        for case in leaves(self.test):
            # Keep only the trailing "Class.method" part of the test id.
            short_id = ".".join(case.id().rsplit(".", 2)[-2:])
            if short_id not in cuda_skip:
                kept.addTest(case)
            else:
                print(f"Skipping {short_id}")
        self.test = kept
2023-11-30 02:52:08 +08:00
class MLXTestCase(unittest.TestCase):
    """Base ``TestCase`` for MLX tests.

    Each test runs with the default device taken from the ``DEVICE``
    environment variable (when set), and the previous default is restored
    afterwards. Also provides helpers to compare MLX results against numpy.
    """

    @property
    def is_apple_silicon(self):
        """True when the host is an arm64 machine running macOS (Darwin)."""
        return platform.machine() == "arm64" and platform.system() == "Darwin"

    def setUp(self):
        # Remember the current default so tearDown can restore it.
        self.default = mx.default_device()
        device = os.getenv("DEVICE", None)
        if device is not None:
            # DEVICE names an attribute of mlx.core, e.g. "cpu" or "gpu".
            device = getattr(mx, device)
            mx.set_default_device(device)

    def tearDown(self):
        mx.set_default_device(self.default)

    def assertCmpNumpy(
        self,
        args: List[Union[Tuple[int, ...], Any]],
        mx_fn: Callable[..., mx.array],
        np_fn: Callable[..., np.ndarray],
        atol=1e-2,
        rtol=1e-2,
        dtype=mx.float32,
        **kwargs,
    ):
        """Check that ``mx_fn`` and ``np_fn`` agree when given the same inputs.

        Every tuple in ``args`` is treated as a shape request and replaced by
        ``mx.random.normal(shape, dtype=dtype)``. Other entries are passed
        through unchanged. Before calling ``np_fn``, ``mx.array`` inputs are
        converted with ``np.array``. ``kwargs`` are forwarded to both functions.

        Args:
            args: positional inputs (tuples become random normal arrays).
            mx_fn: MLX implementation under test.
            np_fn: numpy reference implementation.
            atol, rtol: tolerances forwarded to ``assertEqualArray``.
            dtype: dtype of the generated random inputs.

        Raises:
            AssertionError: if ``dtype`` is ``mx.bfloat16`` (numpy has no
                bfloat16), or if the results differ in shape, dtype or value.
        """
        # Use an explicit failure instead of ``assert`` so the guard is not
        # stripped when Python runs with -O.
        if dtype == mx.bfloat16:
            self.fail("numpy does not support bfloat16")
        args = [
            mx.random.normal(s, dtype=dtype) if isinstance(s, tuple) else s
            for s in args
        ]
        mx_res = mx_fn(*args, **kwargs)
        np_res = np_fn(
            *[np.array(a) if isinstance(a, mx.array) else a for a in args], **kwargs
        )
        return self.assertEqualArray(mx_res, mx.array(np_res), atol=atol, rtol=rtol)

    def assertEqualArray(
        self,
        mx_res: mx.array,
        expected: mx.array,
        atol=1e-2,
        rtol=1e-2,
    ):
        """Assert two arrays match in shape, dtype and (approximately) value.

        If at least one side is an ``mx.array``, the other side is converted
        with ``mx.array`` and values are compared with ``mx.allclose``. If
        neither side is an ``mx.array``, both are converted with
        ``np.asarray`` and compared with ``np.testing.assert_allclose``.

        Args:
            mx_res: value under test (``mx.array`` or array-like).
            expected: reference value (``mx.array`` or array-like).
            atol, rtol: absolute / relative tolerances for the value check.
        """
        host_only = not isinstance(mx_res, mx.array) and not isinstance(
            expected, mx.array
        )
        # Normalize before reading .shape/.dtype so scalars, lists and numpy
        # arrays can be compared, and a numpy/MLX mix is compared as MLX
        # arrays instead of failing on mismatched dtype types.
        if host_only:
            mx_res = np.asarray(mx_res)
            expected = np.asarray(expected)
        else:
            if not isinstance(mx_res, mx.array):
                mx_res = mx.array(mx_res)
            if not isinstance(expected, mx.array):
                expected = mx.array(expected)

        self.assertEqual(
            tuple(mx_res.shape),
            tuple(expected.shape),
            msg=f"shape mismatch expected={expected.shape} got={mx_res.shape}",
        )
        self.assertEqual(
            mx_res.dtype,
            expected.dtype,
            msg=f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}",
        )
        if host_only:
            np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
            return
        self.assertTrue(
            mx.allclose(mx_res, expected, rtol=rtol, atol=atol),
            msg=f"values not close (rtol={rtol}, atol={atol})",
        )