2023-12-01 03:12:53 +08:00
|
|
|
# Copyright © 2023 Apple Inc.
|
|
|
|
|
2023-11-30 02:52:08 +08:00
|
|
|
import os
|
|
|
|
import unittest
|
2023-12-12 12:20:58 +08:00
|
|
|
from typing import Callable, List, Tuple, Union
|
2023-11-30 02:52:08 +08:00
|
|
|
|
|
|
|
import mlx.core as mx
|
2023-12-12 11:40:57 +08:00
|
|
|
import numpy as np
|
2023-11-30 02:52:08 +08:00
|
|
|
|
|
|
|
|
|
|
|
class MLXTestCase(unittest.TestCase):
    """Base class for MLX tests.

    Each test runs with its default device optionally overridden by the
    ``DEVICE`` environment variable; the previous default device is
    restored in ``tearDown``.
    """

    def setUp(self):
        """Save the current default device and apply ``$DEVICE`` if given.

        When ``DEVICE`` is set to a non-empty string, it must name an
        attribute of ``mlx.core`` (e.g. ``cpu`` or ``gpu``); that attribute
        is passed to ``mx.set_default_device``. An unset or empty ``DEVICE``
        leaves the default device unchanged.

        Raises:
            ValueError: if ``DEVICE`` does not name an attribute of
                ``mlx.core``.
        """
        super().setUp()
        self.default = mx.default_device()
        device_name = os.getenv("DEVICE", None)
        # Treat ``DEVICE=`` (empty) like an unset variable instead of
        # looking up an attribute named "".
        if device_name:
            try:
                device = getattr(mx, device_name)
            except AttributeError as err:
                # A bare AttributeError from getattr does not mention the
                # environment variable, which hides the misconfiguration.
                raise ValueError(
                    f"DEVICE={device_name!r} does not name an attribute of "
                    "mlx.core (expected e.g. 'cpu' or 'gpu')"
                ) from err
            mx.set_default_device(device)

    def tearDown(self):
        """Restore the default device saved in ``setUp``."""
        mx.set_default_device(self.default)
        super().tearDown()

    def assertEqualArray(
        self,
        args: List[Union[mx.array, float, int]],
        mlx_func: Callable[..., mx.array],
        expected: mx.array,
        atol=1e-2,
        rtol=1e-2,
    ):
        """Assert that ``mlx_func(*args)`` matches ``expected``.

        The result must have the same shape and dtype as ``expected``, and
        its values must agree within the given tolerances as checked by
        ``np.testing.assert_allclose``.

        Args:
            args: Positional arguments forwarded to ``mlx_func``.
            mlx_func: Callable under test.
            expected: Reference array to compare against.
            atol: Absolute tolerance for the value comparison.
            rtol: Relative tolerance for the value comparison.

        Raises:
            AssertionError: on a shape, dtype, or value mismatch.
        """
        mx_res = mlx_func(*args)
        got_shape = tuple(mx_res.shape)
        expected_shape = tuple(expected.shape)
        # TestCase assertions (unlike bare ``assert``) are not stripped under
        # ``python -O`` and go through unittest's failure reporting.
        self.assertEqual(
            got_shape,
            expected_shape,
            msg=f"shape mismatch: expected {expected_shape}, got {got_shape}",
        )
        self.assertEqual(
            mx_res.dtype,
            expected.dtype,
            msg=f"dtype mismatch: expected {expected.dtype}, got {mx_res.dtype}",
        )
        np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
|