mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
Activations LeakyReLU / PReLU / Softplus / Mish (#109)
* Leaky_relu / prelu / softplus / mish * added tests * updated bench * remove torch refs, add init to PReLU * added arvix reference to mish * added missing docs
This commit is contained in:
@@ -2,8 +2,10 @@
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MLXTestCase(unittest.TestCase):
|
||||
@@ -16,3 +18,16 @@ class MLXTestCase(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
mx.set_default_device(self.default)
|
||||
|
||||
def assertEqualArray(
|
||||
self,
|
||||
args: List[mx.array | float | int],
|
||||
mlx_func: Callable[..., mx.array],
|
||||
expected: mx.array,
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
):
|
||||
mx_res = mlx_func(*args)
|
||||
assert tuple(mx_res.shape) == tuple(expected.shape), "shape mismatch"
|
||||
assert mx_res.dtype == expected.dtype, "dtype mismatch"
|
||||
np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
|
||||
|
@@ -449,31 +449,19 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_step_activation(self):
|
||||
x = mx.arange(-3, 4)
|
||||
expected = mx.array([0, 0, 0, 0, 0, 1, 1])
|
||||
y = nn.Step()(x)
|
||||
self.assertTrue(mx.array_equal(y, expected))
|
||||
|
||||
y = nn.Step(2)(x)
|
||||
expected = mx.array([0, 0, 0, 0, 0, 0, 1])
|
||||
self.assertTrue(mx.array_equal(y, expected))
|
||||
|
||||
def test_selu(self):
|
||||
x = mx.arange(-3, 4)
|
||||
expected = mx.array(
|
||||
[
|
||||
-1.670563817024231,
|
||||
-1.5201621055603027,
|
||||
-1.1113275289535522,
|
||||
0.0,
|
||||
1.0506999492645264,
|
||||
2.1013998985290527,
|
||||
3.152099847793579,
|
||||
]
|
||||
def test_prelu(self):
|
||||
self.assertEqualArray(
|
||||
[mx.array([1.0, -1.0, 0.0, 0.5])],
|
||||
nn.PReLU(),
|
||||
mx.array([1.0, -0.25, 0.0, 0.5]),
|
||||
)
|
||||
|
||||
def test_mish(self):
|
||||
self.assertEqualArray(
|
||||
[mx.array([1.0, -1.0, 0.0, 0.5])],
|
||||
nn.Mish(),
|
||||
mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
|
||||
)
|
||||
y = nn.SELU()(x)
|
||||
self.assertTrue(mx.allclose(y, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user