simple numpy helper for tests (#352)

This commit is contained in:
Diogo
2024-01-03 22:19:19 -05:00
committed by GitHub
parent 526466dd09
commit 1ac18eac20
2 changed files with 74 additions and 76 deletions

View File

@@ -2,6 +2,7 @@
import os
import unittest
from typing import Any, Callable, List, Tuple
import mlx.core as mx
import numpy as np
@@ -18,13 +19,33 @@ class MLXTestCase(unittest.TestCase):
def tearDown(self):
mx.set_default_device(self.default)
def assertCmpNumpy(
self,
shape: List[Tuple[int] | Any],
mx_fn: Callable[..., mx.array],
np_fn: Callable[..., np.array],
atol=1e-2,
rtol=1e-2,
dtype=mx.float32,
**kwargs,
):
assert dtype != mx.bfloat16, "numpy does not support bfloat16"
args = [
mx.random.normal(s, dtype=dtype) if isinstance(s, Tuple) else s
for s in shape
]
mx_res = mx_fn(*args, **kwargs)
np_res = np_fn(
*[np.array(a) if isinstance(a, mx.array) else a for a in args], **kwargs
)
return self.assertEqualArray(mx_res, mx.array(np_res), atol=atol, rtol=rtol)
def assertEqualArray(
self,
mx_res: mx.array,
expected: mx.array,
atol=1e-2,
rtol=1e-2,
**kwargs,
):
assert tuple(mx_res.shape) == tuple(
expected.shape