Activations LeakyReLU / PReLU / Softplus / Mish (#109)

* Leaky_relu / prelu / softplus / mish

* added tests

* updated bench

* remove torch refs, add init to PReLU

* added arXiv reference to mish

* added missing docs
Author: Diogo · 2023-12-11 22:40:57 -05:00 · committed by GitHub
Parent: f5df47ec6e
Commit: 02de234ef0
8 changed files with 133 additions and 31 deletions

Changed file 1 of 8: MLX activation benchmarks (Python benchmark script)

@@ -96,7 +96,35 @@ def softmax_fused(axis, x):
 def relu(x):
     y = x
     for i in range(100):
-        y = mx.maximum(y, 0)
+        y = nn.relu(y)
+    mx.eval(y)
+
+
+def leaky_relu(x: mx.array):
+    y = x
+    for i in range(100):
+        y = nn.leaky_relu(y)
+    mx.eval(y)
+
+
+def prelu(x: mx.array):
+    y = x
+    for i in range(100):
+        y = nn.prelu(y, mx.ones(1))
+    mx.eval(y)
+
+
+def softplus(x: mx.array):
+    y = x
+    for i in range(100):
+        y = nn.softplus(y)
+    mx.eval(y)
+
+
+def mish(x: mx.array):
+    y = x
+    for i in range(100):
+        y = nn.mish(y)
     mx.eval(y)
@@ -334,24 +362,26 @@ if __name__ == "__main__":
     elif args.benchmark == "relu":
         print(bench(relu, x))
-    elif args.benchmark == "leaky_relu":
-        print(bench(leaky_relu, x))
     elif args.benchmark == "elu":
         print(bench(elu, x))
     elif args.benchmark == "relu6":
         print(bench(relu6, x))
-    elif args.benchmark == "softplus":
-        print(bench(softplus, x))
     elif args.benchmark == "celu":
         print(bench(celu, x))
     elif args.benchmark == "log_sigmoid":
         print(bench(log_sigmoid, x))
+    elif args.benchmark == "leaky_relu":
+        print(bench(leaky_relu, x))
+    elif args.benchmark == "prelu":
+        print(bench(prelu, x))
+    elif args.benchmark == "softplus":
+        print(bench(softplus, x))
+    elif args.benchmark == "mish":
+        print(bench(mish, x))
     elif args.benchmark == "scalar_mul":
         print(bench(scalar_mult, x))
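
Every benchmark body above follows the same pattern: apply the activation 100 times to build up a graph, then call mx.eval once at the end, because MLX evaluates lazily and nothing is computed until evaluation is forced. A minimal standalone sketch of that timing pattern, assuming only mlx is installed (the names time_it and mish_step are illustrative and not part of the repository, which instead routes these functions through its own bench helper and argparse flags):

import time

import mlx.core as mx
import mlx.nn as nn


def mish_step(x):
    # Same amount of work as the mish benchmark above: 100 chained applications.
    y = x
    for _ in range(100):
        y = nn.mish(y)
    mx.eval(y)  # force evaluation of the lazily built graph


def time_it(fn, x, warmup=3, repeats=10):
    for _ in range(warmup):
        fn(x)
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x)
    return (time.perf_counter() - start) / repeats


if __name__ == "__main__":
    x = mx.random.uniform(shape=(32, 16, 1024))  # size used by the compare script below
    mx.eval(x)  # materialize the input before timing
    print(f"mish: {time_it(mish_step, x) * 1e3:.3f} ms per run")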

Changed file 2 of 8: PyTorch activation benchmarks (Python benchmark script)

@@ -163,6 +163,22 @@ def log_sigmoid(x):
     sync_if_needed(x)


+@torch.no_grad()
+def prelu(x: torch.Tensor) -> torch.Tensor:
+    y = x
+    for _ in range(100):
+        y = torch.nn.functional.prelu(y, torch.ones(1).to(y.device))
+    sync_if_needed(x)
+
+
+@torch.no_grad()
+def mish(x: torch.Tensor) -> torch.Tensor:
+    y = x
+    for _ in range(100):
+        y = torch.nn.functional.mish(y)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def scalar_mult(x):
     y = x
@@ -376,6 +392,10 @@ if __name__ == "__main__":
     elif args.benchmark == "log_sigmoid":
         print(bench(log_sigmoid, x))
+    elif args.benchmark == "prelu":
+        print(bench(prelu, x))
+    elif args.benchmark == "mish":
+        print(bench(mish, x))
     elif args.benchmark == "scalar_mul":
         print(bench(scalar_mult, x))
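
The PyTorch counterparts mirror the MLX benchmarks, with sync_if_needed standing in for device synchronization, since GPU work is queued asynchronously and a loop that never synchronizes would mostly measure kernel launches. That helper is defined elsewhere in the benchmark file and is not shown in this diff; a plausible sketch of such a helper, offered purely as an assumption for illustration:

import torch


def sync_if_needed(x: torch.Tensor) -> None:
    # Wait for queued device work so the measured time includes actual execution.
    # Sketch only: the benchmark script defines its own version of this helper.
    if x.device.type == "cuda":
        torch.cuda.synchronize()
    elif x.device.type == "mps":
        torch.mps.synchronize()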

Changed file 3 of 8: MLX vs. PyTorch benchmark comparison script

@@ -209,6 +209,11 @@ if __name__ == "__main__":
     compare_filtered("step --size 32x16x1024 --cpu")
     compare_filtered("selu --size 32x16x1024")
     compare_filtered("selu --size 32x16x1024 --cpu")
+    # compare_filtered("mish --size 32x16x1024")  # NOTE: Torch does not implement Mish in MPS atm
+    compare_filtered("mish --size 32x16x1024 --cpu")
+    compare_filtered("prelu --size 32x16x1024")
+    compare_filtered("prelu --size 32x16x1024 --cpu")
     compare_filtered("scalar_mul --size 32x16x1024")
     compare_filtered("scalar_mul --size 32x16x1024 --cpu")
     compare_filtered("cross_entropy --size 256x1024")

Changed file 4 of 8: documentation listings for neural network layers and functions

@@ -146,10 +146,12 @@ Neural Network Layers
    Embedding
    ReLU
+   PReLU
    GELU
    SiLU
    Step
    SELU
+   Mish
    Linear
    Conv1d
    Conv2d
@@ -171,9 +173,11 @@ simple functions.
    gelu_approx
    gelu_fast_approx
    relu
+   prelu
    silu
    step
    selu
+   mish

 Loss Functions
 --------------

Changed file 5 of 8: mlx.nn.layers package exports

@@ -7,6 +7,8 @@ from mlx.nn.layers.activations import (
     SELU,
     LeakyReLU,
     LogSigmoid,
+    Mish,
+    PReLU,
     ReLU,
     ReLU6,
     SiLU,
@@ -19,6 +21,8 @@ from mlx.nn.layers.activations import (
     gelu_fast_approx,
     leaky_relu,
     log_sigmoid,
+    mish,
+    prelu,
     relu,
     relu6,
     selu,
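
With these exports in place, both the class-based and the functional forms of the new activations are reachable through the package, which is how the tests below use them. A small usage sketch (input values borrowed from the tests; the import style assumes the usual import mlx.nn as nn used elsewhere in the test suite):

import mlx.core as mx
import mlx.nn as nn

x = mx.array([1.0, -1.0, 0.0, 0.5])

print(nn.Mish()(x))                   # module form
print(nn.mish(x))                     # functional form, same result
print(nn.PReLU()(x))                  # single slope, default init 0.25
print(nn.prelu(x, mx.array([0.25])))  # functional form with an explicit slope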

Changed file 6 of 8: mlx.nn activation implementations (mlx.nn.layers.activations)

@@ -176,6 +176,33 @@ def selu(x):
     See also :func:`elu`.
     """
     return elu(x, 1.67326) * 1.0507


+def prelu(x: mx.array, alpha: mx.array) -> mx.array:
+    r"""Applies the element-wise function:
+
+    .. math::
+        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
+
+    Here :math:`a` is an array.
+    """
+    return mx.maximum(0, x) + alpha * mx.minimum(0, x)
+
+
+def mish(x: mx.array) -> mx.array:
+    r"""Applies the Mish function, element-wise.
+
+    Mish: A Self Regularized Non-Monotonic Neural Activation Function.
+    Reference: https://arxiv.org/abs/1908.08681
+
+    .. math::
+        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
+    """
+    return x * mx.tanh(softplus(x))
+
+
+@_make_activation_module(mish)
+class Mish(Module):
+    pass
+
+
 @_make_activation_module(relu)
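
As a quick sanity check on the two formulas (hand arithmetic, matching the expected values used in the tests at the end of this diff): softplus(1.0) = log(1 + e) ≈ 1.3133, so mish(1.0) = 1.0 · tanh(1.3133) ≈ 0.8651, and with a slope of 0.25, prelu(-1.0) = max(0, -1) + 0.25 · min(0, -1) = -0.25. A short verification script, included only as an illustration and not part of the commit:

import math

import mlx.core as mx
import mlx.nn as nn

values = [1.0, -1.0, 0.0, 0.5]
x = mx.array(values)


def softplus_ref(v):
    return math.log1p(math.exp(v))


mish_ref = [v * math.tanh(softplus_ref(v)) for v in values]
prelu_ref = [max(0.0, v) + 0.25 * min(0.0, v) for v in values]

print(nn.mish(x), mish_ref)                      # both ≈ [0.8651, -0.3034, 0.0, 0.3752]
print(nn.prelu(x, mx.array([0.25])), prelu_ref)  # both [1.0, -0.25, 0.0, 0.5]
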
@@ -257,6 +284,15 @@ class LogSigmoid(Module):
     pass


+class PReLU(Module):
+    def __init__(self, num_parameters=1, init=0.25):
+        super().__init__()
+        self.weight = mx.full([num_parameters], init)
+
+    def __call__(self, x: mx.array):
+        return prelu(x, self.weight)
+
+
 class GELU(Module):
     r"""Applies the Gaussian Error Linear Units.

Changed file 7 of 8: shared test base class (MLXTestCase)

@@ -2,8 +2,10 @@
 import os
 import unittest
+from typing import Callable, List, Tuple

 import mlx.core as mx
+import numpy as np


 class MLXTestCase(unittest.TestCase):
@@ -16,3 +18,16 @@ class MLXTestCase(unittest.TestCase):
     def tearDown(self):
         mx.set_default_device(self.default)
+
+    def assertEqualArray(
+        self,
+        args: List[mx.array | float | int],
+        mlx_func: Callable[..., mx.array],
+        expected: mx.array,
+        atol=1e-2,
+        rtol=1e-2,
+    ):
+        mx_res = mlx_func(*args)
+        assert tuple(mx_res.shape) == tuple(expected.shape), "shape mismatch"
+        assert mx_res.dtype == expected.dtype, "dtype mismatch"
+        np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
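
assertEqualArray takes the callable's positional arguments as a list, then the callable itself, then the expected array, and checks shape, dtype, and values within the given tolerances. A hedged example of how a test built on MLXTestCase could use it (the class and test names here are illustrative; the commit's actual tests follow in the next file):

import unittest

import mlx.core as mx
import mlx.nn as nn
import mlx_tests  # the base-class module shown above


class TestActivationExample(mlx_tests.MLXTestCase):
    def test_relu_values(self):
        self.assertEqualArray(
            [mx.array([1.0, -1.0, 0.5])],  # arguments forwarded to the callable
            nn.ReLU(),                     # callable under test
            mx.array([1.0, 0.0, 0.5]),     # expected output (shape and dtype checked too)
        )


if __name__ == "__main__":
    unittest.main()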

Changed file 8 of 8: neural network tests (TestNN)

@@ -449,31 +449,19 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [3])
         self.assertEqual(y.dtype, mx.float32)

-    def test_step_activation(self):
-        x = mx.arange(-3, 4)
-        expected = mx.array([0, 0, 0, 0, 0, 1, 1])
-        y = nn.Step()(x)
-        self.assertTrue(mx.array_equal(y, expected))
-
-        y = nn.Step(2)(x)
-        expected = mx.array([0, 0, 0, 0, 0, 0, 1])
-        self.assertTrue(mx.array_equal(y, expected))
-
-    def test_selu(self):
-        x = mx.arange(-3, 4)
-        expected = mx.array(
-            [
-                -1.670563817024231,
-                -1.5201621055603027,
-                -1.1113275289535522,
-                0.0,
-                1.0506999492645264,
-                2.1013998985290527,
-                3.152099847793579,
-            ]
-        )
-        y = nn.SELU()(x)
-        self.assertTrue(mx.allclose(y, expected))
+    def test_prelu(self):
+        self.assertEqualArray(
+            [mx.array([1.0, -1.0, 0.0, 0.5])],
+            nn.PReLU(),
+            mx.array([1.0, -0.25, 0.0, 0.5]),
+        )
+
+    def test_mish(self):
+        self.assertEqualArray(
+            [mx.array([1.0, -1.0, 0.0, 0.5])],
+            nn.Mish(),
+            mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
+        )


 if __name__ == "__main__":