added tri / tril / triu (#170)

* added tri / tril / triu

* fixed tests

* ctest tests

* tri overload and simplified tests

* changes from comment

* more tests for m

* ensure assert if not 2-D

* remove broadcast_to

* minor tweaks

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Diogo
2023-12-15 20:30:34 -05:00
committed by GitHub
parent 2e02acdc83
commit dc2edc762c
9 changed files with 207 additions and 12 deletions

View File

@@ -320,6 +320,30 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertFalse(mx.array_equal(x, y))
self.assertTrue(mx.array_equal(x, y, equal_nan=True))
def test_tri(self):
for shape in [[4], [4, 4], [2, 10]]:
for diag in [-1, 0, 1, -2]:
self.assertEqualArray(
mx.tri(*shape, k=diag), mx.array(np.tri(*shape, k=diag))
)
def test_tril(self):
mt = mx.random.normal((10, 10))
nt = np.array(mt)
for diag in [-1, 0, 1, -2]:
self.assertEqualArray(mx.tril(mt, diag), mx.array(np.tril(nt, diag)))
with self.assertRaises(Exception):
mx.tril(mx.zeros((1)))
def test_triu(self):
mt = mx.random.normal((10, 10))
nt = np.array(mt)
for diag in [-1, 0, 1, -2]:
self.assertEqualArray(mx.triu(mt, diag), mx.array(np.triu(nt, diag)))
with self.assertRaises(Exception):
mx.triu(mx.zeros((1)))
def test_minimum(self):
x = mx.array([0.0, -5, 10.0])
y = mx.array([1.0, -7.0, 3.0])