Mirror of https://github.com/ml-explore/mlx.git
Activations LeakyReLU / PReLU / Softplus / Mish (#109)
* Leaky_relu / prelu / softplus / mish
* added tests
* updated bench
* remove torch refs, add init to PReLU
* added arXiv reference to mish
This commit is contained in:
parent f5df47ec6e
commit 02de234ef0
@@ -96,7 +96,35 @@ def softmax_fused(axis, x):
 def relu(x):
     y = x
     for i in range(100):
-        y = mx.maximum(y, 0)
+        y = nn.relu(y)
     mx.eval(y)


+def leaky_relu(x: mx.array):
+    y = x
+    for i in range(100):
+        y = nn.leaky_relu(y)
+    mx.eval(y)
+
+
+def prelu(x: mx.array):
+    y = x
+    for i in range(100):
+        y = nn.prelu(y, mx.ones(1))
+    mx.eval(y)
+
+
+def softplus(x: mx.array):
+    y = x
+    for i in range(100):
+        y = nn.softplus(y)
+    mx.eval(y)
+
+
+def mish(x: mx.array):
+    y = x
+    for i in range(100):
+        y = nn.mish(y)
+    mx.eval(y)
+
+
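For context on the pattern in these benchmark bodies: MLX evaluates lazily, so the 100 chained calls only record a graph and the final mx.eval(y) is what actually forces the computation being timed. A minimal sketch of exercising one of the new bodies by hand, assuming random input data of the 32x16x1024 size used by compare.py further down (neither the data nor the shape is prescribed by this diff):

import mlx.core as mx
import mlx.nn as nn

# Input of the size the comparison script uses; uniform random data is an
# assumption made only for illustration.
x = mx.random.uniform(shape=[32, 16, 1024])

y = x
for _ in range(100):
    y = nn.mish(y)  # lazily records 100 mish ops
mx.eval(y)          # forces evaluation; this is the work being timed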
@@ -334,24 +362,26 @@ if __name__ == "__main__":
     elif args.benchmark == "relu":
         print(bench(relu, x))

-    elif args.benchmark == "leaky_relu":
-        print(bench(leaky_relu, x))
-
     elif args.benchmark == "elu":
         print(bench(elu, x))

     elif args.benchmark == "relu6":
         print(bench(relu6, x))

-    elif args.benchmark == "softplus":
-        print(bench(softplus, x))
-
     elif args.benchmark == "celu":
         print(bench(celu, x))

     elif args.benchmark == "log_sigmoid":
         print(bench(log_sigmoid, x))

+    elif args.benchmark == "leaky_relu":
+        print(bench(leaky_relu, x))
+    elif args.benchmark == "prelu":
+        print(bench(prelu, x))
+    elif args.benchmark == "softplus":
+        print(bench(softplus, x))
+    elif args.benchmark == "mish":
+        print(bench(mish, x))
     elif args.benchmark == "scalar_mul":
         print(bench(scalar_mult, x))

@@ -163,6 +163,22 @@ def log_sigmoid(x):
     sync_if_needed(x)


+@torch.no_grad()
+def prelu(x: torch.Tensor) -> torch.Tensor:
+    y = x
+    for _ in range(100):
+        y = torch.nn.functional.prelu(y, torch.ones(1).to(y.device))
+    sync_if_needed(x)
+
+
+@torch.no_grad()
+def mish(x: torch.Tensor) -> torch.Tensor:
+    y = x
+    for _ in range(100):
+        return torch.nn.functional.mish(y)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def scalar_mult(x):
     y = x
@@ -376,6 +392,10 @@ if __name__ == "__main__":
     elif args.benchmark == "log_sigmoid":
         print(bench(log_sigmoid, x))

+    elif args.benchmark == "prelu":
+        print(bench(prelu, x))
+    elif args.benchmark == "mish":
+        print(bench(mish, x))
     elif args.benchmark == "scalar_mul":
         print(bench(scalar_mult, x))

@@ -209,6 +209,11 @@ if __name__ == "__main__":
     compare_filtered("step --size 32x16x1024 --cpu")
     compare_filtered("selu --size 32x16x1024")
     compare_filtered("selu --size 32x16x1024 --cpu")
+    # compare_filtered("mish --size 32x16x1024") NOTE: Torch does not implement Mish in MPS atm
+    compare_filtered("mish --size 32x16x1024 --cpu")
+    compare_filtered("prelu --size 32x16x1024")
+    compare_filtered("prelu --size 32x16x1024 --cpu")
+
     compare_filtered("scalar_mul --size 32x16x1024")
     compare_filtered("scalar_mul --size 32x16x1024 --cpu")
     compare_filtered("cross_entropy --size 256x1024")
@@ -146,10 +146,12 @@ Neural Network Layers

    Embedding
    ReLU
+   PReLU
    GELU
    SiLU
    Step
    SELU
+   Mish
    Linear
    Conv1d
    Conv2d
@@ -171,9 +173,11 @@ simple functions.
    gelu_approx
    gelu_fast_approx
    relu
+   prelu
    silu
    step
    selu
+   mish

 Loss Functions
 --------------
@@ -7,6 +7,8 @@ from mlx.nn.layers.activations import (
     SELU,
     LeakyReLU,
     LogSigmoid,
+    Mish,
+    PReLU,
     ReLU,
     ReLU6,
     SiLU,
@@ -19,6 +21,8 @@ from mlx.nn.layers.activations import (
     gelu_fast_approx,
     leaky_relu,
     log_sigmoid,
+    mish,
+    prelu,
     relu,
     relu6,
     selu,
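With these re-exports in place, the new activations resolve from the mlx.nn namespace that the benchmarks and tests in this commit rely on; a minimal hedged check:

import mlx.nn as nn

# Both the module classes and the functional forms are now importable.
print(nn.PReLU, nn.Mish)
print(nn.prelu, nn.mish)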
@@ -176,6 +176,33 @@ def selu(x):
     See also :func:`elu`.
     """
     return elu(x, 1.67326) * 1.0507


+def prelu(x: mx.array, alpha: mx.array) -> mx.array:
+    r"""Applies the element-wise function:
+
+    .. math::
+        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
+
+    Here :math:`a` is an array.
+    """
+    return mx.maximum(0, x) + alpha * mx.minimum(0, x)
+
+
+def mish(x: mx.array) -> mx.array:
+    r"""Applies the Mish function, element-wise.
+    Mish: A Self Regularized Non-Monotonic Neural Activation Function.
+
+    Reference: https://arxiv.org/abs/1908.08681
+
+    .. math::
+        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
+
+    """
+    return x * mx.tanh(softplus(x))
+
+
+@_make_activation_module(mish)
+class Mish(Module):
+    pass
+
+
 @_make_activation_module(relu)
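As a sanity check, the two new functions can be evaluated by hand on the same input the new tests use; this is a hedged sketch, not part of the commit:

import mlx.core as mx
import mlx.nn as nn

x = mx.array([1.0, -1.0, 0.0, 0.5])

# prelu: max(0, x) + alpha * min(0, x); with alpha = 0.25 only the negative
# entry is scaled, giving [1.0, -0.25, 0.0, 0.5].
print(nn.prelu(x, mx.array([0.25])))

# mish: x * tanh(softplus(x)); approximately [0.8651, -0.3034, 0.0, 0.3752].
print(nn.mish(x))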
@@ -257,6 +284,15 @@ class LogSigmoid(Module):
     pass


+class PReLU(Module):
+    def __init__(self, num_parameters=1, init=0.25):
+        super().__init__()
+        self.weight = mx.full([num_parameters], init)
+
+    def __call__(self, x: mx.array):
+        return prelu(x, self.weight)
+
+
 class GELU(Module):
     r"""Applies the Gaussian Error Linear Units.

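A short usage sketch for the two new modules (illustrative only): PReLU stores its slope in self.weight, so the slope is a parameter that can be trained, while Mish is a stateless wrapper generated by @_make_activation_module.

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-2.0, 0.0, 3.0])

act = nn.PReLU(num_parameters=1, init=0.25)
print(act(x))        # [-0.5, 0.0, 3.0]: negative entries scaled by the 0.25 init
print(act.weight)    # the learnable slope, shape [1]

print(nn.Mish()(x))  # same result as the functional nn.mish(x)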
@@ -2,8 +2,10 @@

 import os
 import unittest
+from typing import Callable, List, Tuple

 import mlx.core as mx
+import numpy as np


 class MLXTestCase(unittest.TestCase):
@@ -16,3 +18,16 @@ class MLXTestCase(unittest.TestCase):

     def tearDown(self):
         mx.set_default_device(self.default)
+
+    def assertEqualArray(
+        self,
+        args: List[mx.array | float | int],
+        mlx_func: Callable[..., mx.array],
+        expected: mx.array,
+        atol=1e-2,
+        rtol=1e-2,
+    ):
+        mx_res = mlx_func(*args)
+        assert tuple(mx_res.shape) == tuple(expected.shape), "shape mismatch"
+        assert mx_res.dtype == expected.dtype, "dtype mismatch"
+        np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
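A hedged sketch of the new helper in use; the loose 1e-2 default tolerances are what allow the expectations in the tests below to be written with only four decimal places. The test class and the softplus expectation here are illustrative, not part of the commit:

import mlx.core as mx
import mlx.nn as nn
import mlx_tests  # assumes the repo's test directory is on the path

class TestMoreActivations(mlx_tests.MLXTestCase):
    def test_softplus(self):
        self.assertEqualArray(
            [mx.array([1.0, -1.0, 0.0, 0.5])],
            nn.softplus,
            mx.array([1.3133, 0.3133, 0.6931, 0.9741]),
        )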
@@ -449,31 +449,19 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [3])
         self.assertEqual(y.dtype, mx.float32)

-    def test_step_activation(self):
-        x = mx.arange(-3, 4)
-        expected = mx.array([0, 0, 0, 0, 0, 1, 1])
-        y = nn.Step()(x)
-        self.assertTrue(mx.array_equal(y, expected))
-
-        y = nn.Step(2)(x)
-        expected = mx.array([0, 0, 0, 0, 0, 0, 1])
-        self.assertTrue(mx.array_equal(y, expected))
-
-    def test_selu(self):
-        x = mx.arange(-3, 4)
-        expected = mx.array(
-            [
-                -1.670563817024231,
-                -1.5201621055603027,
-                -1.1113275289535522,
-                0.0,
-                1.0506999492645264,
-                2.1013998985290527,
-                3.152099847793579,
-            ]
-        )
-        y = nn.SELU()(x)
-        self.assertTrue(mx.allclose(y, expected))
+    def test_prelu(self):
+        self.assertEqualArray(
+            [mx.array([1.0, -1.0, 0.0, 0.5])],
+            nn.PReLU(),
+            mx.array([1.0, -0.25, 0.0, 0.5]),
+        )
+
+    def test_mish(self):
+        self.assertEqualArray(
+            [mx.array([1.0, -1.0, 0.0, 0.5])],
+            nn.Mish(),
+            mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
+        )


 if __name__ == "__main__":