mlx/python/tests/mlx_tests.py
Awni Hannun 8993382aaa
Buffer Donation (#519)
* buffer donation

* fix to move shared pointer

* format

* gpu in place for copy and binary

* revert ops test

* cpu in place

* a little cleanup

* remove useless bench
2024-01-26 16:30:33 -08:00

74 lines
2.3 KiB
Python

# Copyright © 2023 Apple Inc.
import os
import platform
import unittest
from typing import Any, Callable, List, Tuple, Union
import mlx.core as mx
import numpy as np
class MLXTestCase(unittest.TestCase):
    """Base ``unittest.TestCase`` for MLX Python tests.

    Adds optional default-device selection through the ``DEVICE``
    environment variable and helpers for comparing MLX results against
    NumPy references.
    """

    @property
    def is_linux(self):
        """Whether ``platform.system()`` reports ``"Linux"``."""
        return platform.system() == "Linux"

    def setUp(self):
        """Remember the current default device and optionally override it.

        If the ``DEVICE`` environment variable is set, its value is looked up
        as an attribute of ``mlx.core`` (presumably ``"cpu"`` or ``"gpu"``)
        and installed as the default device for the duration of the test.
        """
        self.default = mx.default_device()
        device = os.getenv("DEVICE", None)
        if device is not None:
            # An unknown name raises AttributeError here, failing loudly.
            device = getattr(mx, device)
            mx.set_default_device(device)

    def tearDown(self):
        """Restore the default device saved in ``setUp``."""
        mx.set_default_device(self.default)

    def assertCmpNumpy(
        self,
        args: List[Union[Tuple[int], Any]],
        mx_fn: Callable[..., mx.array],
        np_fn: Callable[..., np.ndarray],
        atol=1e-2,
        rtol=1e-2,
        dtype=mx.float32,
        **kwargs,
    ):
        """Assert that ``mx_fn`` and ``np_fn`` agree on the same inputs.

        Args:
            args: Positional arguments for both functions. Any ``tuple`` in
                ``args`` is treated as a shape request and replaced by
                ``mx.random.normal(shape, dtype=dtype)``; other values are
                passed through unchanged.
            mx_fn: MLX implementation under test.
            np_fn: NumPy reference implementation. ``mx.array`` arguments are
                converted with ``np.array`` before being passed to it.
            atol: Absolute tolerance for the value comparison.
            rtol: Relative tolerance for the value comparison.
            dtype: Dtype of the randomly generated inputs. ``mx.bfloat16`` is
                rejected because NumPy cannot represent it.
            **kwargs: Extra keyword arguments forwarded to both functions.

        Raises:
            AssertionError: If ``dtype`` is bfloat16, or if the results differ
                in shape, dtype or value.
        """
        # self.fail instead of a bare ``assert`` so the check is not stripped
        # under ``python -O``; it raises the same AssertionError and message.
        if dtype == mx.bfloat16:
            self.fail("numpy does not support bfloat16")
        args = [
            mx.random.normal(s, dtype=dtype) if isinstance(s, tuple) else s
            for s in args
        ]
        mx_res = mx_fn(*args, **kwargs)
        np_args = [np.array(a) if isinstance(a, mx.array) else a for a in args]
        np_res = np_fn(*np_args, **kwargs)
        return self.assertEqualArray(mx_res, mx.array(np_res), atol=atol, rtol=rtol)

    def assertEqualArray(
        self,
        mx_res: mx.array,
        expected: mx.array,
        atol=1e-2,
        rtol=1e-2,
    ):
        """Assert that two arrays match in shape, dtype and value.

        If exactly one side is an ``mx.array``, the other side is converted
        with ``mx.array`` so both are compared in MLX. If neither side is an
        ``mx.array``, both are converted with ``np.asarray`` and compared with
        ``np.testing.assert_allclose``.

        Args:
            mx_res: Array produced by the code under test.
            expected: Reference array.
            atol: Absolute tolerance for the value comparison.
            rtol: Relative tolerance for the value comparison.

        Raises:
            AssertionError: On a shape, dtype or value mismatch.
        """
        res_is_mx = isinstance(mx_res, mx.array)
        expected_is_mx = isinstance(expected, mx.array)

        # Normalize *before* the shape/dtype checks: an MLX dtype never
        # compares equal to a NumPy dtype, and plain Python values have no
        # ``.shape`` / ``.dtype`` at all.
        if not res_is_mx and not expected_is_mx:
            mx_res = np.asarray(mx_res)
            expected = np.asarray(expected)
        elif not res_is_mx:
            mx_res = mx.array(mx_res)
        elif not expected_is_mx:
            expected = mx.array(expected)

        self.assertEqual(
            tuple(mx_res.shape),
            tuple(expected.shape),
            msg=f"shape mismatch expected={expected.shape} got={mx_res.shape}",
        )
        self.assertEqual(
            mx_res.dtype,
            expected.dtype,
            msg=f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}",
        )

        if not res_is_mx and not expected_is_mx:
            np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
            return
        # Build the failure message only on failure.
        if not mx.allclose(mx_res, expected, rtol=rtol, atol=atol):
            self.fail(f"values not close (rtol={rtol}, atol={atol})")