diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py
index 73ea2b423..71a3d6c01 100644
--- a/benchmarks/python/comparative/bench_mlx.py
+++ b/benchmarks/python/comparative/bench_mlx.py
@@ -96,7 +96,35 @@ def softmax_fused(axis, x):
 def relu(x):
     y = x
     for i in range(100):
-        y = mx.maximum(y, 0)
+        y = nn.relu(y)
+    mx.eval(y)
+
+
+def leaky_relu(x: mx.array):
+    y = x
+    for i in range(100):
+        y = nn.leaky_relu(y)
+    mx.eval(y)
+
+
+def prelu(x: mx.array):
+    y = x
+    for i in range(100):
+        y = nn.prelu(y, mx.ones(1))
+    mx.eval(y)
+
+
+def softplus(x: mx.array):
+    y = x
+    for i in range(100):
+        y = nn.softplus(y)
+    mx.eval(y)
+
+
+def mish(x: mx.array):
+    y = x
+    for i in range(100):
+        y = nn.mish(y)
     mx.eval(y)


@@ -334,24 +362,26 @@ if __name__ == "__main__":
     elif args.benchmark == "relu":
         print(bench(relu, x))

-    elif args.benchmark == "leaky_relu":
-        print(bench(leaky_relu, x))
-
     elif args.benchmark == "elu":
         print(bench(elu, x))

     elif args.benchmark == "relu6":
         print(bench(relu6, x))

-    elif args.benchmark == "softplus":
-        print(bench(softplus, x))
-
     elif args.benchmark == "celu":
         print(bench(celu, x))

     elif args.benchmark == "log_sigmoid":
         print(bench(log_sigmoid, x))

+    elif args.benchmark == "leaky_relu":
+        print(bench(leaky_relu, x))
+    elif args.benchmark == "prelu":
+        print(bench(prelu, x))
+    elif args.benchmark == "softplus":
+        print(bench(softplus, x))
+    elif args.benchmark == "mish":
+        print(bench(mish, x))
     elif args.benchmark == "scalar_mul":
         print(bench(scalar_mult, x))

diff --git a/benchmarks/python/comparative/bench_torch.py b/benchmarks/python/comparative/bench_torch.py
index ed31e0119..d556dd051 100644
--- a/benchmarks/python/comparative/bench_torch.py
+++ b/benchmarks/python/comparative/bench_torch.py
@@ -163,6 +163,22 @@ def log_sigmoid(x):
     sync_if_needed(x)


+@torch.no_grad()
+def prelu(x: torch.Tensor):
+    y = x
+    for _ in range(100):
+        y = torch.nn.functional.prelu(y, torch.ones(1).to(y.device))
+    sync_if_needed(x)
+
+
+@torch.no_grad()
+def mish(x: torch.Tensor):
+    y = x
+    for _ in range(100):
+        y = torch.nn.functional.mish(y)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def scalar_mult(x):
     y = x
@@ -376,6 +392,10 @@ if __name__ == "__main__":
     elif args.benchmark == "log_sigmoid":
         print(bench(log_sigmoid, x))

+    elif args.benchmark == "prelu":
+        print(bench(prelu, x))
+    elif args.benchmark == "mish":
+        print(bench(mish, x))
     elif args.benchmark == "scalar_mul":
         print(bench(scalar_mult, x))

diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py
index a81d768e7..c54af3a46 100644
--- a/benchmarks/python/comparative/compare.py
+++ b/benchmarks/python/comparative/compare.py
@@ -209,6 +209,11 @@ if __name__ == "__main__":
     compare_filtered("step --size 32x16x1024 --cpu")
     compare_filtered("selu --size 32x16x1024")
     compare_filtered("selu --size 32x16x1024 --cpu")
+    # compare_filtered("mish --size 32x16x1024")  NOTE: Torch does not implement Mish on MPS at the moment
+    compare_filtered("mish --size 32x16x1024 --cpu")
+    compare_filtered("prelu --size 32x16x1024")
+    compare_filtered("prelu --size 32x16x1024 --cpu")
+
     compare_filtered("scalar_mul --size 32x16x1024")
     compare_filtered("scalar_mul --size 32x16x1024 --cpu")
     compare_filtered("cross_entropy --size 256x1024")
diff --git a/docs/src/python/nn.rst b/docs/src/python/nn.rst
index 478bba79b..93cfd8c78 100644
--- a/docs/src/python/nn.rst
+++ b/docs/src/python/nn.rst
@@ -146,10 +146,12 @@ Neural Network Layers

    Embedding
    ReLU
+   PReLU
    GELU
    SiLU
    Step
    SELU
+   Mish
    Linear
    Conv1d
    Conv2d
@@ -171,9 +173,11 @@ simple functions.
    gelu_approx
    gelu_fast_approx
    relu
+   prelu
    silu
    step
    selu
+   mish

 Loss Functions
 --------------
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 544753d97..04557843b 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -7,6 +7,8 @@ from mlx.nn.layers.activations import (
     SELU,
     LeakyReLU,
     LogSigmoid,
+    Mish,
+    PReLU,
     ReLU,
     ReLU6,
     SiLU,
@@ -19,6 +21,8 @@ from mlx.nn.layers.activations import (
     gelu_fast_approx,
     leaky_relu,
     log_sigmoid,
+    mish,
+    prelu,
     relu,
     relu6,
     selu,
diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index c62694288..88b3cc2b5 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -176,6 +176,35 @@ def selu(x):
     See also :func:`elu`.
     """
     return elu(x, 1.67326) * 1.0507
+
+
+def prelu(x: mx.array, alpha: mx.array) -> mx.array:
+    r"""Applies the element-wise function:
+
+    .. math::
+        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
+
+    Here :math:`a` is an array.
+    """
+    return mx.maximum(0, x) + alpha * mx.minimum(0, x)
+
+
+def mish(x: mx.array) -> mx.array:
+    r"""Applies the Mish function, element-wise.
+    Mish: A Self Regularized Non-Monotonic Neural Activation Function.
+
+    Reference: https://arxiv.org/abs/1908.08681
+
+    .. math::
+        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
+
+    """
+    return x * mx.tanh(softplus(x))
+
+
+@_make_activation_module(mish)
+class Mish(Module):
+    pass


 @_make_activation_module(relu)
@@ -257,6 +286,15 @@ class LogSigmoid(Module):
     pass


+class PReLU(Module):
+    def __init__(self, num_parameters=1, init=0.25):
+        super().__init__()
+        self.weight = mx.full([num_parameters], init)
+
+    def __call__(self, x: mx.array):
+        return prelu(x, self.weight)
+
+
 class GELU(Module):
     r"""Applies the Gaussian Error Linear Units.

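For reviewers who want to try the new layers interactively, a quick spot check along the following lines should reproduce the values used in the tests below. This is an illustrative sketch, not part of the patch; it assumes mlx is installed from this branch so that mlx.nn exposes prelu, mish, PReLU, and Mish.

    # Hypothetical spot check of the activations added above (not part of the diff).
    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([1.0, -1.0, 0.0, 0.5])

    # PReLU with the default slope of 0.25: max(0, x) + 0.25 * min(0, x)
    print(nn.PReLU()(x))  # expected: [1.0, -0.25, 0.0, 0.5]

    # Mish: x * tanh(softplus(x)); the module and the closed form should agree
    print(nn.Mish()(x))
    print(x * mx.tanh(nn.softplus(x)))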
diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py
index 7bbadc9cb..9400950a1 100644
--- a/python/tests/mlx_tests.py
+++ b/python/tests/mlx_tests.py
@@ -2,8 +2,10 @@

 import os
 import unittest
+from typing import Callable, List, Union

 import mlx.core as mx
+import numpy as np


 class MLXTestCase(unittest.TestCase):
@@ -16,3 +18,16 @@ class MLXTestCase(unittest.TestCase):

     def tearDown(self):
         mx.set_default_device(self.default)
+
+    def assertEqualArray(
+        self,
+        args: List[Union[mx.array, float, int]],
+        mlx_func: Callable[..., mx.array],
+        expected: mx.array,
+        atol=1e-2,
+        rtol=1e-2,
+    ):
+        mx_res = mlx_func(*args)
+        assert tuple(mx_res.shape) == tuple(expected.shape), "shape mismatch"
+        assert mx_res.dtype == expected.dtype, "dtype mismatch"
+        np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index ef3386bc1..f5597474d 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -449,31 +449,19 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [3])
         self.assertEqual(y.dtype, mx.float32)

-    def test_step_activation(self):
-        x = mx.arange(-3, 4)
-        expected = mx.array([0, 0, 0, 0, 0, 1, 1])
-        y = nn.Step()(x)
-        self.assertTrue(mx.array_equal(y, expected))
-
-        y = nn.Step(2)(x)
-        expected = mx.array([0, 0, 0, 0, 0, 0, 1])
-        self.assertTrue(mx.array_equal(y, expected))
-
-    def test_selu(self):
-        x = mx.arange(-3, 4)
-        expected = mx.array(
-            [
-                -1.670563817024231,
-                -1.5201621055603027,
-                -1.1113275289535522,
-                0.0,
-                1.0506999492645264,
-                2.1013998985290527,
-                3.152099847793579,
-            ]
+    def test_prelu(self):
+        self.assertEqualArray(
+            [mx.array([1.0, -1.0, 0.0, 0.5])],
+            nn.PReLU(),
+            mx.array([1.0, -0.25, 0.0, 0.5]),
+        )
+
+    def test_mish(self):
+        self.assertEqualArray(
+            [mx.array([1.0, -1.0, 0.0, 0.5])],
+            nn.Mish(),
+            mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
         )
-        y = nn.SELU()(x)
-        self.assertTrue(mx.allclose(y, expected))


 if __name__ == "__main__":
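The expected arrays hard-coded in test_prelu and test_mish follow directly from the definitions above. A NumPy-only reference computation (a verification sketch, not part of the change) reproduces them:

    # Reference values for the new tests, computed with NumPy alone.
    import numpy as np

    x = np.array([1.0, -1.0, 0.0, 0.5], dtype=np.float32)

    # PReLU with the default slope of 0.25
    print(np.maximum(0.0, x) + 0.25 * np.minimum(0.0, x))  # matches mx.array([1.0, -0.25, 0.0, 0.5]) in test_prelu

    # Mish: x * tanh(softplus(x)) = x * tanh(log(1 + exp(x)))
    print(np.round(x * np.tanh(np.log1p(np.exp(x))), 4))  # matches mx.array([0.8651, -0.3034, 0.0000, 0.3752]) in test_mish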