added tri / tril / triu (#170)

* added tri / tril / triu

* fixed tests

* ctest tests

* tri overload and simplified tests

* changes from comment

* more tests for m

* ensure assert if not 2-D

* remove broadcast_to

* minor tweaks

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Diogo
2023-12-15 20:30:34 -05:00
committed by GitHub
parent 2e02acdc83
commit dc2edc762c
9 changed files with 207 additions and 12 deletions

View File

@@ -2,7 +2,6 @@
import os
import unittest
from typing import Callable, List, Tuple, Union
import mlx.core as mx
import numpy as np
@@ -21,13 +20,16 @@ class MLXTestCase(unittest.TestCase):
def assertEqualArray(
self,
args: List[Union[mx.array, float, int]],
mlx_func: Callable[..., mx.array],
mx_res: mx.array,
expected: mx.array,
atol=1e-2,
rtol=1e-2,
**kwargs,
):
mx_res = mlx_func(*args)
assert tuple(mx_res.shape) == tuple(expected.shape), "shape mismatch"
assert mx_res.dtype == expected.dtype, "dtype mismatch"
assert tuple(mx_res.shape) == tuple(
expected.shape
), f"shape mismatch expected={expected.shape} got={mx_res.shape}"
assert (
mx_res.dtype == expected.dtype
), f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}"
np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)