mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
simple numpy helper for tests (#352)
This commit is contained in:
parent
526466dd09
commit
1ac18eac20
@ -2,6 +2,7 @@
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from typing import Any, Callable, List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
@ -18,13 +19,33 @@ class MLXTestCase(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
mx.set_default_device(self.default)
|
||||
|
||||
def assertCmpNumpy(
|
||||
self,
|
||||
shape: List[Tuple[int] | Any],
|
||||
mx_fn: Callable[..., mx.array],
|
||||
np_fn: Callable[..., np.array],
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
dtype=mx.float32,
|
||||
**kwargs,
|
||||
):
|
||||
assert dtype != mx.bfloat16, "numpy does not support bfloat16"
|
||||
args = [
|
||||
mx.random.normal(s, dtype=dtype) if isinstance(s, Tuple) else s
|
||||
for s in shape
|
||||
]
|
||||
mx_res = mx_fn(*args, **kwargs)
|
||||
np_res = np_fn(
|
||||
*[np.array(a) if isinstance(a, mx.array) else a for a in args], **kwargs
|
||||
)
|
||||
return self.assertEqualArray(mx_res, mx.array(np_res), atol=atol, rtol=rtol)
|
||||
|
||||
def assertEqualArray(
|
||||
self,
|
||||
mx_res: mx.array,
|
||||
expected: mx.array,
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
**kwargs,
|
||||
):
|
||||
assert tuple(mx_res.shape) == tuple(
|
||||
expected.shape
|
||||
|
@ -324,24 +324,18 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
def test_tri(self):
|
||||
for shape in [[4], [4, 4], [2, 10]]:
|
||||
for diag in [-1, 0, 1, -2]:
|
||||
self.assertEqualArray(
|
||||
mx.tri(*shape, k=diag), mx.array(np.tri(*shape, k=diag))
|
||||
)
|
||||
self.assertCmpNumpy(shape, mx.tri, np.tri, k=diag)
|
||||
|
||||
def test_tril(self):
|
||||
mt = mx.random.normal((10, 10))
|
||||
nt = np.array(mt)
|
||||
for diag in [-1, 0, 1, -2]:
|
||||
self.assertEqualArray(mx.tril(mt, diag), mx.array(np.tril(nt, diag)))
|
||||
self.assertCmpNumpy([(10, 10)], mx.tril, np.tril, k=diag)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
mx.tril(mx.zeros((1)))
|
||||
|
||||
def test_triu(self):
|
||||
mt = mx.random.normal((10, 10))
|
||||
nt = np.array(mt)
|
||||
for diag in [-1, 0, 1, -2]:
|
||||
self.assertEqualArray(mx.triu(mt, diag), mx.array(np.triu(nt, diag)))
|
||||
self.assertCmpNumpy([(10, 10)], mx.triu, np.triu, k=diag)
|
||||
with self.assertRaises(Exception):
|
||||
mx.triu(mx.zeros((1)))
|
||||
|
||||
@ -1260,20 +1254,17 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(a_bwd[4:-2, 2:-4], df[0]).item())
|
||||
|
||||
def test_where(self):
|
||||
a = mx.array([[1, 2], [3, 4]])
|
||||
out = mx.where(True, a, 1)
|
||||
out_np = np.where(True, a, 1)
|
||||
self.assertTrue(np.array_equal(out, out_np))
|
||||
|
||||
out = mx.where(True, 1, a)
|
||||
out_np = np.where(True, 1, a)
|
||||
self.assertTrue(np.array_equal(out, out_np))
|
||||
|
||||
condition = mx.array([[True, False], [False, True]])
|
||||
b = mx.array([5, 6])
|
||||
out = mx.where(condition, a, b)
|
||||
out_np = np.where(condition, a, b)
|
||||
self.assertTrue(np.array_equal(out, out_np))
|
||||
self.assertCmpNumpy([True, mx.array([[1, 2], [3, 4]]), 1], mx.where, np.where)
|
||||
self.assertCmpNumpy([True, 1, mx.array([[1, 2], [3, 4]])], mx.where, np.where)
|
||||
self.assertCmpNumpy(
|
||||
[
|
||||
mx.array([[True, False], [False, True]]),
|
||||
mx.array([[1, 2], [3, 4]]),
|
||||
mx.array([5, 6]),
|
||||
],
|
||||
mx.where,
|
||||
np.where,
|
||||
)
|
||||
|
||||
def test_as_strided(self):
|
||||
x_npy = np.random.randn(128).astype(np.float32)
|
||||
@ -1408,24 +1399,13 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual((a + b)[0, 0].item(), 2)
|
||||
|
||||
def test_eye(self):
|
||||
eye_matrix = mx.eye(3)
|
||||
np_eye_matrix = np.eye(3)
|
||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||
|
||||
self.assertCmpNumpy([3], mx.eye, np.eye)
|
||||
# Test for non-square matrix
|
||||
eye_matrix = mx.eye(3, 4)
|
||||
np_eye_matrix = np.eye(3, 4)
|
||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||
|
||||
self.assertCmpNumpy([3, 4], mx.eye, np.eye)
|
||||
# Test with positive k parameter
|
||||
eye_matrix = mx.eye(3, 4, k=1)
|
||||
np_eye_matrix = np.eye(3, 4, k=1)
|
||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||
|
||||
self.assertCmpNumpy([3, 4], mx.eye, np.eye, k=1)
|
||||
# Test with negative k parameter
|
||||
eye_matrix = mx.eye(5, 6, k=-2)
|
||||
np_eye_matrix = np.eye(5, 6, k=-2)
|
||||
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
|
||||
self.assertCmpNumpy([5, 6], mx.eye, np.eye, k=-2)
|
||||
|
||||
def test_stack(self):
|
||||
a = mx.ones((2,))
|
||||
@ -1518,50 +1498,47 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_repeat(self):
|
||||
# Setup data for the tests
|
||||
data = np.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])
|
||||
data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])
|
||||
# Test repeat along axis 0
|
||||
repeat_axis_0 = mx.repeat(mx.array(data), 2, axis=0)
|
||||
expected_axis_0 = np.repeat(data, 2, axis=0)
|
||||
|
||||
self.assertEqualArray(repeat_axis_0, mx.array(expected_axis_0))
|
||||
|
||||
self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=0)
|
||||
# Test repeat along axis 1
|
||||
repeat_axis_1 = mx.repeat(mx.array(data), 2, axis=1)
|
||||
expected_axis_1 = np.repeat(data, 2, axis=1)
|
||||
self.assertEqualArray(repeat_axis_1, mx.array(expected_axis_1))
|
||||
|
||||
self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=1)
|
||||
# Test repeat along the last axis (default)
|
||||
repeat_axis_2 = mx.repeat(mx.array(data), 2)
|
||||
expected_axis_2 = np.repeat(data, 2)
|
||||
self.assertEqualArray(repeat_axis_2, mx.array(expected_axis_2))
|
||||
|
||||
self.assertCmpNumpy([data, 2], mx.repeat, np.repeat)
|
||||
# Test repeat with a 1D array along axis 0
|
||||
data_2 = mx.array([1, 3, 2])
|
||||
repeat_2 = mx.repeat(mx.array(data_2), 3, axis=0)
|
||||
expected_2 = np.repeat(data_2, 3)
|
||||
self.assertEqualArray(repeat_2, mx.array(expected_2))
|
||||
|
||||
self.assertCmpNumpy([mx.array([1, 3, 2]), 3], mx.repeat, np.repeat, axis=0)
|
||||
# Test repeat with a 2D array along axis 0
|
||||
data_3 = mx.array([[1, 2, 3], [4, 5, 4], [0, 1, 2]])
|
||||
repeat_3 = mx.repeat(mx.array(data_3), 2, axis=0)
|
||||
expected_3 = np.repeat(data_3, 2, axis=0)
|
||||
self.assertEqualArray(repeat_3, mx.array(expected_3))
|
||||
self.assertCmpNumpy(
|
||||
[mx.array([[1, 2, 3], [4, 5, 4], [0, 1, 2]]), 2],
|
||||
mx.repeat,
|
||||
np.repeat,
|
||||
axis=0,
|
||||
)
|
||||
|
||||
def test_tensordot(self):
|
||||
x = mx.arange(60.0).reshape(3, 4, 5)
|
||||
y = mx.arange(24.0).reshape(4, 3, 2)
|
||||
z = mx.tensordot(x, y, dims=([1, 0], [0, 1]))
|
||||
self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=([1, 0], [0, 1]))))
|
||||
x = mx.random.normal((3, 4, 5))
|
||||
y = mx.random.normal((4, 5, 6))
|
||||
z = mx.tensordot(x, y, dims=2)
|
||||
self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=2)))
|
||||
x = mx.random.normal((3, 5, 4, 6))
|
||||
y = mx.random.normal((6, 4, 5, 3))
|
||||
z = mx.tensordot(x, y, dims=([2, 1, 3], [1, 2, 0]))
|
||||
self.assertEqualArray(
|
||||
z, mx.array(np.tensordot(x, y, axes=([2, 1, 3], [1, 2, 0])))
|
||||
)
|
||||
for dtype in [mx.float16, mx.float32]:
|
||||
with self.subTest(dtype=dtype):
|
||||
self.assertCmpNumpy(
|
||||
[(3, 4, 5), (4, 3, 2)],
|
||||
mx.tensordot,
|
||||
lambda x, y, dims: np.tensordot(x, y, axes=dims),
|
||||
dtype=dtype,
|
||||
dims=([1, 0], [0, 1]),
|
||||
)
|
||||
self.assertCmpNumpy(
|
||||
[(3, 4, 5), (4, 5, 6)],
|
||||
mx.tensordot,
|
||||
lambda x, y, dims: np.tensordot(x, y, axes=dims),
|
||||
dtype=dtype,
|
||||
dims=2,
|
||||
)
|
||||
self.assertCmpNumpy(
|
||||
[(3, 5, 4, 6), (6, 4, 5, 3)],
|
||||
mx.tensordot,
|
||||
lambda x, y, dims: np.tensordot(x, y, axes=dims),
|
||||
dtype=dtype,
|
||||
dims=([2, 1, 3], [1, 2, 0]),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user