more tests

This commit is contained in:
Gabrijel Boduljak 2023-12-22 05:34:06 +01:00 committed by Awni Hannun
parent f82ab0eec9
commit 5a184d5b5d

View File

@ -42,11 +42,27 @@ class TestLinalg(mlx_tests.MLXTestCase):
) )
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
# Test when no axes and no ords are provided # Test only axis provided
for keepdims in [True, False]: for shape in [(3,), (2, 3), (2, 3, 3)]:
out_np = np.linalg.norm(x_np, keepdims=keepdims) x_mx = mx.arange(math.prod(shape)).reshape(shape)
out_mx = mx.linalg.norm(x_mx, keepdims=keepdims) x_np = np.arange(math.prod(shape)).reshape(shape)
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
for num_axes in range(1, len(shape)):
for axis in itertools.combinations(range(len(shape)), num_axes):
for keepdims in [True, False]:
out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims)
out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims)
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
# Test only ord provided
for shape in [(3,), (2, 3)]:
x_mx = mx.arange(math.prod(shape)).reshape(shape)
x_np = np.arange(math.prod(shape)).reshape(shape)
for o in [None, 1, -1, float("inf"), -float("inf")]:
for keepdims in [True, False]:
out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims)
out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
if __name__ == "__main__": if __name__ == "__main__":