simple numpy helper for tests (#352)

This commit is contained in:
Diogo 2024-01-03 22:19:19 -05:00 committed by GitHub
parent 526466dd09
commit 1ac18eac20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 74 additions and 76 deletions

View File

@ -2,6 +2,7 @@
import os import os
import unittest import unittest
from typing import Any, Callable, List, Tuple
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
@ -18,13 +19,33 @@ class MLXTestCase(unittest.TestCase):
def tearDown(self): def tearDown(self):
mx.set_default_device(self.default) mx.set_default_device(self.default)
def assertCmpNumpy(
self,
shape: List[Tuple[int] | Any],
mx_fn: Callable[..., mx.array],
np_fn: Callable[..., np.array],
atol=1e-2,
rtol=1e-2,
dtype=mx.float32,
**kwargs,
):
assert dtype != mx.bfloat16, "numpy does not support bfloat16"
args = [
mx.random.normal(s, dtype=dtype) if isinstance(s, Tuple) else s
for s in shape
]
mx_res = mx_fn(*args, **kwargs)
np_res = np_fn(
*[np.array(a) if isinstance(a, mx.array) else a for a in args], **kwargs
)
return self.assertEqualArray(mx_res, mx.array(np_res), atol=atol, rtol=rtol)
def assertEqualArray( def assertEqualArray(
self, self,
mx_res: mx.array, mx_res: mx.array,
expected: mx.array, expected: mx.array,
atol=1e-2, atol=1e-2,
rtol=1e-2, rtol=1e-2,
**kwargs,
): ):
assert tuple(mx_res.shape) == tuple( assert tuple(mx_res.shape) == tuple(
expected.shape expected.shape

View File

@ -324,24 +324,18 @@ class TestOps(mlx_tests.MLXTestCase):
def test_tri(self): def test_tri(self):
for shape in [[4], [4, 4], [2, 10]]: for shape in [[4], [4, 4], [2, 10]]:
for diag in [-1, 0, 1, -2]: for diag in [-1, 0, 1, -2]:
self.assertEqualArray( self.assertCmpNumpy(shape, mx.tri, np.tri, k=diag)
mx.tri(*shape, k=diag), mx.array(np.tri(*shape, k=diag))
)
def test_tril(self): def test_tril(self):
mt = mx.random.normal((10, 10))
nt = np.array(mt)
for diag in [-1, 0, 1, -2]: for diag in [-1, 0, 1, -2]:
self.assertEqualArray(mx.tril(mt, diag), mx.array(np.tril(nt, diag))) self.assertCmpNumpy([(10, 10)], mx.tril, np.tril, k=diag)
with self.assertRaises(Exception): with self.assertRaises(Exception):
mx.tril(mx.zeros((1))) mx.tril(mx.zeros((1)))
def test_triu(self): def test_triu(self):
mt = mx.random.normal((10, 10))
nt = np.array(mt)
for diag in [-1, 0, 1, -2]: for diag in [-1, 0, 1, -2]:
self.assertEqualArray(mx.triu(mt, diag), mx.array(np.triu(nt, diag))) self.assertCmpNumpy([(10, 10)], mx.triu, np.triu, k=diag)
with self.assertRaises(Exception): with self.assertRaises(Exception):
mx.triu(mx.zeros((1))) mx.triu(mx.zeros((1)))
@ -1260,20 +1254,17 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(a_bwd[4:-2, 2:-4], df[0]).item()) self.assertTrue(mx.allclose(a_bwd[4:-2, 2:-4], df[0]).item())
def test_where(self): def test_where(self):
a = mx.array([[1, 2], [3, 4]]) self.assertCmpNumpy([True, mx.array([[1, 2], [3, 4]]), 1], mx.where, np.where)
out = mx.where(True, a, 1) self.assertCmpNumpy([True, 1, mx.array([[1, 2], [3, 4]])], mx.where, np.where)
out_np = np.where(True, a, 1) self.assertCmpNumpy(
self.assertTrue(np.array_equal(out, out_np)) [
mx.array([[True, False], [False, True]]),
out = mx.where(True, 1, a) mx.array([[1, 2], [3, 4]]),
out_np = np.where(True, 1, a) mx.array([5, 6]),
self.assertTrue(np.array_equal(out, out_np)) ],
mx.where,
condition = mx.array([[True, False], [False, True]]) np.where,
b = mx.array([5, 6]) )
out = mx.where(condition, a, b)
out_np = np.where(condition, a, b)
self.assertTrue(np.array_equal(out, out_np))
def test_as_strided(self): def test_as_strided(self):
x_npy = np.random.randn(128).astype(np.float32) x_npy = np.random.randn(128).astype(np.float32)
@ -1408,24 +1399,13 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual((a + b)[0, 0].item(), 2) self.assertEqual((a + b)[0, 0].item(), 2)
def test_eye(self): def test_eye(self):
eye_matrix = mx.eye(3) self.assertCmpNumpy([3], mx.eye, np.eye)
np_eye_matrix = np.eye(3)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
# Test for non-square matrix # Test for non-square matrix
eye_matrix = mx.eye(3, 4) self.assertCmpNumpy([3, 4], mx.eye, np.eye)
np_eye_matrix = np.eye(3, 4)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
# Test with positive k parameter # Test with positive k parameter
eye_matrix = mx.eye(3, 4, k=1) self.assertCmpNumpy([3, 4], mx.eye, np.eye, k=1)
np_eye_matrix = np.eye(3, 4, k=1)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
# Test with negative k parameter # Test with negative k parameter
eye_matrix = mx.eye(5, 6, k=-2) self.assertCmpNumpy([5, 6], mx.eye, np.eye, k=-2)
np_eye_matrix = np.eye(5, 6, k=-2)
self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
def test_stack(self): def test_stack(self):
a = mx.ones((2,)) a = mx.ones((2,))
@ -1518,49 +1498,46 @@ class TestOps(mlx_tests.MLXTestCase):
def test_repeat(self): def test_repeat(self):
# Setup data for the tests # Setup data for the tests
data = np.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]]) data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])
# Test repeat along axis 0 # Test repeat along axis 0
repeat_axis_0 = mx.repeat(mx.array(data), 2, axis=0) self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=0)
expected_axis_0 = np.repeat(data, 2, axis=0)
self.assertEqualArray(repeat_axis_0, mx.array(expected_axis_0))
# Test repeat along axis 1 # Test repeat along axis 1
repeat_axis_1 = mx.repeat(mx.array(data), 2, axis=1) self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=1)
expected_axis_1 = np.repeat(data, 2, axis=1)
self.assertEqualArray(repeat_axis_1, mx.array(expected_axis_1))
# Test repeat along the last axis (default) # Test repeat along the last axis (default)
repeat_axis_2 = mx.repeat(mx.array(data), 2) self.assertCmpNumpy([data, 2], mx.repeat, np.repeat)
expected_axis_2 = np.repeat(data, 2)
self.assertEqualArray(repeat_axis_2, mx.array(expected_axis_2))
# Test repeat with a 1D array along axis 0 # Test repeat with a 1D array along axis 0
data_2 = mx.array([1, 3, 2]) self.assertCmpNumpy([mx.array([1, 3, 2]), 3], mx.repeat, np.repeat, axis=0)
repeat_2 = mx.repeat(mx.array(data_2), 3, axis=0)
expected_2 = np.repeat(data_2, 3)
self.assertEqualArray(repeat_2, mx.array(expected_2))
# Test repeat with a 2D array along axis 0 # Test repeat with a 2D array along axis 0
data_3 = mx.array([[1, 2, 3], [4, 5, 4], [0, 1, 2]]) self.assertCmpNumpy(
repeat_3 = mx.repeat(mx.array(data_3), 2, axis=0) [mx.array([[1, 2, 3], [4, 5, 4], [0, 1, 2]]), 2],
expected_3 = np.repeat(data_3, 2, axis=0) mx.repeat,
self.assertEqualArray(repeat_3, mx.array(expected_3)) np.repeat,
axis=0,
)
def test_tensordot(self): def test_tensordot(self):
x = mx.arange(60.0).reshape(3, 4, 5) for dtype in [mx.float16, mx.float32]:
y = mx.arange(24.0).reshape(4, 3, 2) with self.subTest(dtype=dtype):
z = mx.tensordot(x, y, dims=([1, 0], [0, 1])) self.assertCmpNumpy(
self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=([1, 0], [0, 1])))) [(3, 4, 5), (4, 3, 2)],
x = mx.random.normal((3, 4, 5)) mx.tensordot,
y = mx.random.normal((4, 5, 6)) lambda x, y, dims: np.tensordot(x, y, axes=dims),
z = mx.tensordot(x, y, dims=2) dtype=dtype,
self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=2))) dims=([1, 0], [0, 1]),
x = mx.random.normal((3, 5, 4, 6)) )
y = mx.random.normal((6, 4, 5, 3)) self.assertCmpNumpy(
z = mx.tensordot(x, y, dims=([2, 1, 3], [1, 2, 0])) [(3, 4, 5), (4, 5, 6)],
self.assertEqualArray( mx.tensordot,
z, mx.array(np.tensordot(x, y, axes=([2, 1, 3], [1, 2, 0]))) lambda x, y, dims: np.tensordot(x, y, axes=dims),
dtype=dtype,
dims=2,
)
self.assertCmpNumpy(
[(3, 5, 4, 6), (6, 4, 5, 3)],
mx.tensordot,
lambda x, y, dims: np.tensordot(x, y, axes=dims),
dtype=dtype,
dims=([2, 1, 3], [1, 2, 0]),
) )