simple numpy helper for tests (#352)

This commit is contained in:
Diogo 2024-01-03 22:19:19 -05:00 committed by GitHub
parent 526466dd09
commit 1ac18eac20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 74 additions and 76 deletions

View File

@ -2,6 +2,7 @@
import os
import unittest
from typing import Any, Callable, List, Tuple
import mlx.core as mx
import numpy as np
@ -18,13 +19,33 @@ class MLXTestCase(unittest.TestCase):
def tearDown(self):
mx.set_default_device(self.default)
def assertCmpNumpy(
self,
shape: List[Tuple[int] | Any],
mx_fn: Callable[..., mx.array],
np_fn: Callable[..., np.array],
atol=1e-2,
rtol=1e-2,
dtype=mx.float32,
**kwargs,
):
assert dtype != mx.bfloat16, "numpy does not support bfloat16"
args = [
mx.random.normal(s, dtype=dtype) if isinstance(s, Tuple) else s
for s in shape
]
mx_res = mx_fn(*args, **kwargs)
np_res = np_fn(
*[np.array(a) if isinstance(a, mx.array) else a for a in args], **kwargs
)
return self.assertEqualArray(mx_res, mx.array(np_res), atol=atol, rtol=rtol)
def assertEqualArray(
self,
mx_res: mx.array,
expected: mx.array,
atol=1e-2,
rtol=1e-2,
**kwargs,
):
assert tuple(mx_res.shape) == tuple(
expected.shape

View File

@ -324,24 +324,18 @@ class TestOps(mlx_tests.MLXTestCase):
def test_tri(self):
for shape in [[4], [4, 4], [2, 10]]:
for diag in [-1, 0, 1, -2]:
self.assertEqualArray(
mx.tri(*shape, k=diag), mx.array(np.tri(*shape, k=diag))
)
self.assertCmpNumpy(shape, mx.tri, np.tri, k=diag)
def test_tril(self):
mt = mx.random.normal((10, 10))
nt = np.array(mt)
for diag in [-1, 0, 1, -2]:
self.assertEqualArray(mx.tril(mt, diag), mx.array(np.tril(nt, diag)))
self.assertCmpNumpy([(10, 10)], mx.tril, np.tril, k=diag)
with self.assertRaises(Exception):
mx.tril(mx.zeros((1)))
def test_triu(self):
mt = mx.random.normal((10, 10))
nt = np.array(mt)
for diag in [-1, 0, 1, -2]:
self.assertEqualArray(mx.triu(mt, diag), mx.array(np.triu(nt, diag)))
self.assertCmpNumpy([(10, 10)], mx.triu, np.triu, k=diag)
with self.assertRaises(Exception):
mx.triu(mx.zeros((1)))
@ -1260,20 +1254,17 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(a_bwd[4:-2, 2:-4], df[0]).item())
def test_where(self):
a = mx.array([[1, 2], [3, 4]])
out = mx.where(True, a, 1)
out_np = np.where(True, a, 1)
self.assertTrue(np.array_equal(out, out_np))
out = mx.where(True, 1, a)
out_np = np.where(True, 1, a)
self.assertTrue(np.array_equal(out, out_np))
condition = mx.array([[True, False], [False, True]])
b = mx.array([5, 6])
out = mx.where(condition, a, b)
out_np = np.where(condition, a, b)
self.assertTrue(np.array_equal(out, out_np))
self.assertCmpNumpy([True, mx.array([[1, 2], [3, 4]]), 1], mx.where, np.where)
self.assertCmpNumpy([True, 1, mx.array([[1, 2], [3, 4]])], mx.where, np.where)
self.assertCmpNumpy(
[
mx.array([[True, False], [False, True]]),
mx.array([[1, 2], [3, 4]]),
mx.array([5, 6]),
],
mx.where,
np.where,
)
def test_as_strided(self):
x_npy = np.random.randn(128).astype(np.float32)
@ -1408,24 +1399,13 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual((a + b)[0, 0].item(), 2)
def test_eye(self):
eye_matrix = mx.eye(3)
np_eye_matrix = np.eye(3)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
self.assertCmpNumpy([3], mx.eye, np.eye)
# Test for non-square matrix
eye_matrix = mx.eye(3, 4)
np_eye_matrix = np.eye(3, 4)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
self.assertCmpNumpy([3, 4], mx.eye, np.eye)
# Test with positive k parameter
eye_matrix = mx.eye(3, 4, k=1)
np_eye_matrix = np.eye(3, 4, k=1)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
self.assertCmpNumpy([3, 4], mx.eye, np.eye, k=1)
# Test with negative k parameter
eye_matrix = mx.eye(5, 6, k=-2)
np_eye_matrix = np.eye(5, 6, k=-2)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
self.assertCmpNumpy([5, 6], mx.eye, np.eye, k=-2)
def test_stack(self):
a = mx.ones((2,))
@ -1518,50 +1498,47 @@ class TestOps(mlx_tests.MLXTestCase):
def test_repeat(self):
# Setup data for the tests
data = np.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])
data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])
# Test repeat along axis 0
repeat_axis_0 = mx.repeat(mx.array(data), 2, axis=0)
expected_axis_0 = np.repeat(data, 2, axis=0)
self.assertEqualArray(repeat_axis_0, mx.array(expected_axis_0))
self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=0)
# Test repeat along axis 1
repeat_axis_1 = mx.repeat(mx.array(data), 2, axis=1)
expected_axis_1 = np.repeat(data, 2, axis=1)
self.assertEqualArray(repeat_axis_1, mx.array(expected_axis_1))
self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=1)
# Test repeat along the last axis (default)
repeat_axis_2 = mx.repeat(mx.array(data), 2)
expected_axis_2 = np.repeat(data, 2)
self.assertEqualArray(repeat_axis_2, mx.array(expected_axis_2))
self.assertCmpNumpy([data, 2], mx.repeat, np.repeat)
# Test repeat with a 1D array along axis 0
data_2 = mx.array([1, 3, 2])
repeat_2 = mx.repeat(mx.array(data_2), 3, axis=0)
expected_2 = np.repeat(data_2, 3)
self.assertEqualArray(repeat_2, mx.array(expected_2))
self.assertCmpNumpy([mx.array([1, 3, 2]), 3], mx.repeat, np.repeat, axis=0)
# Test repeat with a 2D array along axis 0
data_3 = mx.array([[1, 2, 3], [4, 5, 4], [0, 1, 2]])
repeat_3 = mx.repeat(mx.array(data_3), 2, axis=0)
expected_3 = np.repeat(data_3, 2, axis=0)
self.assertEqualArray(repeat_3, mx.array(expected_3))
self.assertCmpNumpy(
[mx.array([[1, 2, 3], [4, 5, 4], [0, 1, 2]]), 2],
mx.repeat,
np.repeat,
axis=0,
)
def test_tensordot(self):
x = mx.arange(60.0).reshape(3, 4, 5)
y = mx.arange(24.0).reshape(4, 3, 2)
z = mx.tensordot(x, y, dims=([1, 0], [0, 1]))
self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=([1, 0], [0, 1]))))
x = mx.random.normal((3, 4, 5))
y = mx.random.normal((4, 5, 6))
z = mx.tensordot(x, y, dims=2)
self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=2)))
x = mx.random.normal((3, 5, 4, 6))
y = mx.random.normal((6, 4, 5, 3))
z = mx.tensordot(x, y, dims=([2, 1, 3], [1, 2, 0]))
self.assertEqualArray(
z, mx.array(np.tensordot(x, y, axes=([2, 1, 3], [1, 2, 0])))
)
for dtype in [mx.float16, mx.float32]:
with self.subTest(dtype=dtype):
self.assertCmpNumpy(
[(3, 4, 5), (4, 3, 2)],
mx.tensordot,
lambda x, y, dims: np.tensordot(x, y, axes=dims),
dtype=dtype,
dims=([1, 0], [0, 1]),
)
self.assertCmpNumpy(
[(3, 4, 5), (4, 5, 6)],
mx.tensordot,
lambda x, y, dims: np.tensordot(x, y, axes=dims),
dtype=dtype,
dims=2,
)
self.assertCmpNumpy(
[(3, 5, 4, 6), (6, 4, 5, 3)],
mx.tensordot,
lambda x, y, dims: np.tensordot(x, y, axes=dims),
dtype=dtype,
dims=([2, 1, 3], [1, 2, 0]),
)
if __name__ == "__main__":