added some tests for python bindings

This commit is contained in:
Gabrijel Boduljak 2023-12-22 00:33:36 +01:00 committed by Awni Hannun
parent fa096d64a2
commit 26bb16e768

View File

@ -0,0 +1,41 @@
# Copyright © 2023 Apple Inc.
import itertools
import unittest
import mlx.core as mx
import mlx_tests
import numpy as np
class TestLinalg(mlx_tests.MLXTestCase):
def test_norm(self):
def check_mx_np(a_mx, a_np):
self.assertTrue(np.allclose(a_np, a_mx, atol=1e-5, rtol=1e-6))
x_mx = mx.arange(18).reshape((2, 3, 3))
x_np = np.arange(18).reshape((2, 3, 3))
for num_axes in range(1, 3):
for axis in itertools.combinations(range(3), num_axes):
if num_axes == 1:
ords = [None, 0.5, 0, 1, 2, 3, -1, 1]
else:
ords = [None, "fro", -1, 1]
for o in ords:
for keepdims in [True, False]:
if o:
out_np = np.linalg.norm(
x_np, ord=o, axis=axis, keepdims=keepdims
)
out_mx = mx.linalg.norm(
x_mx, ord=o, axis=axis, keepdims=keepdims
)
else:
out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims)
out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims)
assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
if __name__ == "__main__":
unittest.main()