mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
added tri / tril / triu (#170)
* added tri / tril / triu * fixed tests * ctest tests * tri overload and simplified tests * changes from comment * more tests for m * ensure assert if not 2-D * remove broadcast_to * minor tweaks --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from typing import Callable, List, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
@@ -21,13 +20,16 @@ class MLXTestCase(unittest.TestCase):
|
||||
|
||||
def assertEqualArray(
|
||||
self,
|
||||
args: List[Union[mx.array, float, int]],
|
||||
mlx_func: Callable[..., mx.array],
|
||||
mx_res: mx.array,
|
||||
expected: mx.array,
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
**kwargs,
|
||||
):
|
||||
mx_res = mlx_func(*args)
|
||||
assert tuple(mx_res.shape) == tuple(expected.shape), "shape mismatch"
|
||||
assert mx_res.dtype == expected.dtype, "dtype mismatch"
|
||||
assert tuple(mx_res.shape) == tuple(
|
||||
expected.shape
|
||||
), f"shape mismatch expected={expected.shape} got={mx_res.shape}"
|
||||
assert (
|
||||
mx_res.dtype == expected.dtype
|
||||
), f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}"
|
||||
np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
|
||||
|
@@ -451,15 +451,13 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_prelu(self):
|
||||
self.assertEqualArray(
|
||||
[mx.array([1.0, -1.0, 0.0, 0.5])],
|
||||
nn.PReLU(),
|
||||
nn.PReLU()(mx.array([1.0, -1.0, 0.0, 0.5])),
|
||||
mx.array([1.0, -0.25, 0.0, 0.5]),
|
||||
)
|
||||
|
||||
def test_mish(self):
|
||||
self.assertEqualArray(
|
||||
[mx.array([1.0, -1.0, 0.0, 0.5])],
|
||||
nn.Mish(),
|
||||
nn.Mish()(mx.array([1.0, -1.0, 0.0, 0.5])),
|
||||
mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
|
||||
)
|
||||
|
||||
|
@@ -320,6 +320,30 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertFalse(mx.array_equal(x, y))
|
||||
self.assertTrue(mx.array_equal(x, y, equal_nan=True))
|
||||
|
||||
def test_tri(self):
|
||||
for shape in [[4], [4, 4], [2, 10]]:
|
||||
for diag in [-1, 0, 1, -2]:
|
||||
self.assertEqualArray(
|
||||
mx.tri(*shape, k=diag), mx.array(np.tri(*shape, k=diag))
|
||||
)
|
||||
|
||||
def test_tril(self):
|
||||
mt = mx.random.normal((10, 10))
|
||||
nt = np.array(mt)
|
||||
for diag in [-1, 0, 1, -2]:
|
||||
self.assertEqualArray(mx.tril(mt, diag), mx.array(np.tril(nt, diag)))
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
mx.tril(mx.zeros((1)))
|
||||
|
||||
def test_triu(self):
|
||||
mt = mx.random.normal((10, 10))
|
||||
nt = np.array(mt)
|
||||
for diag in [-1, 0, 1, -2]:
|
||||
self.assertEqualArray(mx.triu(mt, diag), mx.array(np.triu(nt, diag)))
|
||||
with self.assertRaises(Exception):
|
||||
mx.triu(mx.zeros((1)))
|
||||
|
||||
def test_minimum(self):
|
||||
x = mx.array([0.0, -5, 10.0])
|
||||
y = mx.array([1.0, -7.0, 3.0])
|
||||
|
Reference in New Issue
Block a user