mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
added tri / tril / triu (#170)
* added tri / tril / triu * fixed tests * ctest tests * tri overload and simplified tests * changes from comment * more tests for m * ensure assert if not 2-D * remove broadcast_to * minor tweaks --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -451,15 +451,13 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_prelu(self):
|
||||
self.assertEqualArray(
|
||||
[mx.array([1.0, -1.0, 0.0, 0.5])],
|
||||
nn.PReLU(),
|
||||
nn.PReLU()(mx.array([1.0, -1.0, 0.0, 0.5])),
|
||||
mx.array([1.0, -0.25, 0.0, 0.5]),
|
||||
)
|
||||
|
||||
def test_mish(self):
|
||||
self.assertEqualArray(
|
||||
[mx.array([1.0, -1.0, 0.0, 0.5])],
|
||||
nn.Mish(),
|
||||
nn.Mish()(mx.array([1.0, -1.0, 0.0, 0.5])),
|
||||
mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user