mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-02 17:28:12 +08:00
simple numpy helper for tests (#352)
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from typing import Any, Callable, List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
@@ -18,13 +19,33 @@ class MLXTestCase(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
mx.set_default_device(self.default)
|
||||
|
||||
def assertCmpNumpy(
|
||||
self,
|
||||
shape: List[Tuple[int] | Any],
|
||||
mx_fn: Callable[..., mx.array],
|
||||
np_fn: Callable[..., np.array],
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
dtype=mx.float32,
|
||||
**kwargs,
|
||||
):
|
||||
assert dtype != mx.bfloat16, "numpy does not support bfloat16"
|
||||
args = [
|
||||
mx.random.normal(s, dtype=dtype) if isinstance(s, Tuple) else s
|
||||
for s in shape
|
||||
]
|
||||
mx_res = mx_fn(*args, **kwargs)
|
||||
np_res = np_fn(
|
||||
*[np.array(a) if isinstance(a, mx.array) else a for a in args], **kwargs
|
||||
)
|
||||
return self.assertEqualArray(mx_res, mx.array(np_res), atol=atol, rtol=rtol)
|
||||
|
||||
def assertEqualArray(
|
||||
self,
|
||||
mx_res: mx.array,
|
||||
expected: mx.array,
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
**kwargs,
|
||||
):
|
||||
assert tuple(mx_res.shape) == tuple(
|
||||
expected.shape
|
||||
|
||||
Reference in New Issue
Block a user