Activations LeakyReLU / PReLU / Softplus / Mish (#109)

* Leaky_relu / prelu / softplus / mish

* added tests

* updated bench

* remove torch refs, add init to PReLU

* added arXiv reference to mish

* added missing docs
This commit is contained in:
Diogo
2023-12-11 22:40:57 -05:00
committed by GitHub
parent f5df47ec6e
commit 02de234ef0
8 changed files with 133 additions and 31 deletions

View File

@@ -2,8 +2,10 @@
import os
import unittest
from typing import Callable, List, Tuple
import mlx.core as mx
import numpy as np
class MLXTestCase(unittest.TestCase):
@@ -16,3 +18,16 @@ class MLXTestCase(unittest.TestCase):
def tearDown(self):
    # Restore the previously-active default device so a test that switched
    # devices does not leak its choice into subsequent tests.
    # NOTE(review): self.default is presumably captured in setUp — not
    # visible in this span; confirm against the full file.
    mx.set_default_device(self.default)
def assertEqualArray(
    self,
    args: List[mx.array | float | int],
    mlx_func: Callable[..., mx.array],
    expected: mx.array,
    atol=1e-2,
    rtol=1e-2,
):
    """Apply ``mlx_func`` to ``args`` and compare the result to ``expected``.

    The result must match ``expected`` exactly in shape and dtype; the
    values are then compared element-wise with
    ``np.testing.assert_allclose`` under the given ``atol``/``rtol``
    tolerances.
    """
    result = mlx_func(*args)
    # Normalize both shapes to tuples before comparing (guards against
    # list-vs-tuple shape representations).
    assert tuple(result.shape) == tuple(expected.shape), "shape mismatch"
    assert result.dtype == expected.dtype, "dtype mismatch"
    np.testing.assert_allclose(result, expected, rtol=rtol, atol=atol)