diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py
index d73610b17..73ea2b423 100644
--- a/benchmarks/python/comparative/bench_mlx.py
+++ b/benchmarks/python/comparative/bench_mlx.py
@@ -223,6 +223,20 @@ def topk(axis, x):
     mx.eval(ys)
 
 
+def step_function(x):
+    y = x
+    for i in range(100):
+        y = nn.step(y)
+    mx.eval(y)
+
+
+def selu(x):
+    y = x
+    for i in range(100):
+        y = nn.selu(y)
+    mx.eval(y)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("benchmark", help="Choose the benchmark to run")
@@ -372,5 +386,11 @@ if __name__ == "__main__":
     elif args.benchmark == "topk":
         print(bench(topk, axis, x))
 
+    elif args.benchmark == "step":
+        print(bench(step_function, x))
+
+    elif args.benchmark == "selu":
+        print(bench(selu, x))
+
     else:
         raise ValueError("Unknown benchmark")
diff --git a/benchmarks/python/comparative/bench_torch.py b/benchmarks/python/comparative/bench_torch.py
index 1c92bd7a1..ed31e0119 100644
--- a/benchmarks/python/comparative/bench_torch.py
+++ b/benchmarks/python/comparative/bench_torch.py
@@ -257,6 +257,14 @@ def topk(axis, x):
     sync_if_needed(x)
 
 
+@torch.no_grad()
+def selu(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.selu(y)
+    sync_if_needed(x)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("benchmark", help="Choose the benchmark to run")
diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py
index 3e64625e5..a81d768e7 100644
--- a/benchmarks/python/comparative/compare.py
+++ b/benchmarks/python/comparative/compare.py
@@ -205,6 +205,10 @@ if __name__ == "__main__":
     compare_filtered("celu --size 32x16x1024 --cpu")
     compare_filtered("log_sigmoid --size 32x16x1024")
     compare_filtered("log_sigmoid --size 32x16x1024 --cpu")
+    compare_filtered("step --size 32x16x1024")
+    compare_filtered("step --size 32x16x1024 --cpu")
+    compare_filtered("selu --size 32x16x1024")
+    compare_filtered("selu --size 32x16x1024 --cpu")
     compare_filtered("scalar_mul --size 32x16x1024")
     compare_filtered("scalar_mul --size 32x16x1024 --cpu")
     compare_filtered("cross_entropy --size 256x1024")
diff --git a/docs/src/python/nn.rst b/docs/src/python/nn.rst
index 15497206a..478bba79b 100644
--- a/docs/src/python/nn.rst
+++ b/docs/src/python/nn.rst
@@ -97,7 +97,7 @@ Updating the parameters
 
 MLX modules allow accessing and updating individual parameters. However, most
 times we need to update large subsets of a module's parameters. This action is
-performed by :meth:`Module.update`. 
+performed by :meth:`Module.update`.
 
 Value and grad
 --------------
@@ -148,6 +148,8 @@ Neural Network Layers
    ReLU
    GELU
    SiLU
+   Step
+   SELU
    Linear
    Conv1d
    Conv2d
@@ -170,6 +172,8 @@ simple functions.
    gelu_fast_approx
    relu
    silu
+   step
+   selu
 
 Loss Functions
 --------------
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index e78f161a8..544753d97 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -4,12 +4,14 @@ from mlx.nn.layers.activations import (
     CELU,
     ELU,
     GELU,
+    SELU,
     LeakyReLU,
     LogSigmoid,
     ReLU,
     ReLU6,
     SiLU,
     Softplus,
+    Step,
     celu,
     elu,
     gelu,
@@ -19,8 +21,10 @@ from mlx.nn.layers.activations import (
     log_sigmoid,
     relu,
     relu6,
+    selu,
     silu,
     softplus,
+    step,
 )
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.containers import Sequential
diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index f924626a8..c62694288 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -74,7 +74,7 @@ def celu(x, alpha=1.0):
 
 
 def silu(x):
-    r"""Applies the Sigmoid Linear Unit.
+    r"""Applies the Sigmoid Linear Unit. Also known as Swish.
 
     Applies :math:`x \sigma(x)` element wise, where :math:`\sigma(\cdot)` is
     the logistic sigmoid.
@@ -143,6 +143,41 @@ class Sigmoid(Module):
     pass
 
 
+def step(x: mx.array, threshold: float = 0.0):
+    r"""Applies the Step Activation Function.
+
+    This function implements a binary step activation, where the output is set
+    to 1 if the input is greater than a specified threshold, and 0 otherwise.
+
+    .. math::
+        \text{step}(x) = \begin{cases}
+        0 & \text{if } x < \text{threshold} \\
+        1 & \text{if } x \geq \text{threshold}
+        \end{cases}
+
+    Args:
+        threshold: The value to threshold at.
+    """
+
+    return mx.where(x > threshold, 1, 0)
+
+
+def selu(x):
+    r"""Applies the Scaled Exponential Linear Unit.
+
+    .. math::
+        \text{selu}(x) = \begin{cases}
+        \lambda x & \text{if } x > 0 \\
+        \lambda \alpha (\exp(x) - 1) & \text{if } x \leq 0
+        \end{cases}
+
+    where :math:`\lambda = 1.0507` and :math:`\alpha = 1.67326`.
+
+    See also :func:`elu`.
+    """
+    return elu(x, 1.67326) * 1.0507
+
+
 @_make_activation_module(relu)
 class ReLU(Module):
     pass
@@ -274,3 +309,32 @@ def tanh(x):
 @_make_activation_module(tanh)
 class Tanh(Module):
     pass
+
+
+class Step(Module):
+    r"""Applies the Step Activation Function.
+
+    This function implements a binary step activation, where the output is set
+    to 1 if the input is greater than a specified threshold, and 0 otherwise.
+
+    .. math::
+        \text{step}(x) = \begin{cases}
+        0 & \text{if } x < \text{threshold} \\
+        1 & \text{if } x \geq \text{threshold}
+        \end{cases}
+
+    Args:
+        threshold: The value to threshold at.
+ """ + + def __init__(self, threshold: float = 0.0): + super().__init__() + self.threshold = threshold + + def __call__(self, x: mx.array): + return step(x, self.threshold) + + +@_make_activation_module(selu) +class SELU(Module): + pass diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 8be69535f..ef3386bc1 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -449,6 +449,32 @@ class TestNN(mlx_tests.MLXTestCase): self.assertEqual(y.shape, [3]) self.assertEqual(y.dtype, mx.float32) + def test_step_activation(self): + x = mx.arange(-3, 4) + expected = mx.array([0, 0, 0, 0, 0, 1, 1]) + y = nn.Step()(x) + self.assertTrue(mx.array_equal(y, expected)) + + y = nn.Step(2)(x) + expected = mx.array([0, 0, 0, 0, 0, 0, 1]) + self.assertTrue(mx.array_equal(y, expected)) + + def test_selu(self): + x = mx.arange(-3, 4) + expected = mx.array( + [ + -1.670563817024231, + -1.5201621055603027, + -1.1113275289535522, + 0.0, + 1.0506999492645264, + 2.1013998985290527, + 3.152099847793579, + ] + ) + y = nn.SELU()(x) + self.assertTrue(mx.allclose(y, expected)) + if __name__ == "__main__": unittest.main()