Add Step, ELU, SELU, Swish activation functions (#117)

* Add Step, ELU, SELU, Swish activation functions

This commit adds the Step, ELU, SELU and Swish activation functions.

* add to the docs

* review
Nicholas Santavas 2023-12-12 02:04:07 +01:00 committed by GitHub
parent b9226c367c
commit f5df47ec6e
7 changed files with 132 additions and 2 deletions
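
For orientation, here is a minimal usage sketch of the activations this commit introduces, assuming the names and import paths shown in the diffs below (nn.step, nn.selu, nn.Step, nn.SELU):

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-2.0, -0.5, 0.0, 0.5, 2.0])

# Functional forms: step is 1 where x is strictly above the threshold, 0 otherwise;
# selu scales elu(x, 1.67326) by 1.0507.
print(nn.step(x))
print(nn.selu(x))

# Module (layer) forms wrap the same functions.
print(nn.Step(threshold=0.5)(x))
print(nn.SELU()(x))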

View File

@@ -223,6 +223,20 @@ def topk(axis, x):
    mx.eval(ys)


def step_function(x):
    y = x
    for i in range(100):
        y = nn.step(x)
    mx.eval(y)


def selu(x):
    y = x
    for i in range(100):
        y = nn.selu(x)
    mx.eval(y)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", help="Choose the benchmark to run")
@@ -372,5 +386,11 @@ if __name__ == "__main__":
    elif args.benchmark == "topk":
        print(bench(topk, axis, x))

    elif args.benchmark == "step":
        print(bench(step_function, x))

    elif args.benchmark == "selu":
        print(bench(selu, x))

    else:
        raise ValueError("Unknown benchmark")
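
The new branches are timed the same way as the existing ones. The script's bench() helper is not shown in this hunk; a rough, hypothetical stand-in for the pattern it implements looks like this:

import time

import mlx.core as mx
import mlx.nn as nn

def bench_once(fn, x, warmup=2, repeats=10):
    # Hypothetical stand-in for the benchmark script's bench() helper:
    # warm up, then average wall-clock time over a few repeats.
    for _ in range(warmup):
        fn(x)
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x)
    return (time.perf_counter() - start) / repeats

x = mx.random.normal((32, 16, 1024))
mx.eval(x)
print("step:", bench_once(lambda a: mx.eval(nn.step(a)), x))
print("selu:", bench_once(lambda a: mx.eval(nn.selu(a)), x))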

View File

@@ -257,6 +257,14 @@ def topk(axis, x):
    sync_if_needed(x)


@torch.no_grad()
def selu(x):
    y = x
    for i in range(100):
        y = torch.nn.functional.selu(y)
    sync_if_needed(x)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", help="Choose the benchmark to run")
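
The PyTorch baseline is a meaningful comparison point because torch's built-in selu uses (to higher precision) the same lambda ~ 1.0507 and alpha ~ 1.67326 constants that the MLX implementation below hard-codes. A quick cross-check sketch, not part of the commit, with an arbitrary tolerance:

import torch

x = torch.linspace(-3, 3, 7)
lam, alpha = 1.0507, 1.67326
# Piecewise SELU with the truncated constants used in the MLX diff.
reference = torch.where(x > 0, lam * x, lam * alpha * (torch.exp(x) - 1))
assert torch.allclose(torch.nn.functional.selu(x), reference, atol=1e-4)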

View File

@@ -205,6 +205,10 @@ if __name__ == "__main__":
    compare_filtered("celu --size 32x16x1024 --cpu")
    compare_filtered("log_sigmoid --size 32x16x1024")
    compare_filtered("log_sigmoid --size 32x16x1024 --cpu")
    compare_filtered("step --size 32x16x1024")
    compare_filtered("step --size 32x16x1024 --cpu")
    compare_filtered("selu --size 32x16x1024")
    compare_filtered("selu --size 32x16x1024 --cpu")
    compare_filtered("scalar_mul --size 32x16x1024")
    compare_filtered("scalar_mul --size 32x16x1024 --cpu")
    compare_filtered("cross_entropy --size 256x1024")

View File

@@ -97,7 +97,7 @@ Updating the parameters
MLX modules allow accessing and updating individual parameters. However, most
times we need to update large subsets of a module's parameters. This action is
performed by :meth:`Module.update`.

Value and grad
--------------
@@ -148,6 +148,8 @@ Neural Network Layers
   ReLU
   GELU
   SiLU
   Step
   SELU
   Linear
   Conv1d
   Conv2d

@@ -170,6 +172,8 @@ simple functions.
   gelu_fast_approx
   relu
   silu
   step
   selu

Loss Functions
--------------

View File

@@ -4,12 +4,14 @@ from mlx.nn.layers.activations import (
    CELU,
    ELU,
    GELU,
    SELU,
    LeakyReLU,
    LogSigmoid,
    ReLU,
    ReLU6,
    SiLU,
    Softplus,
    Step,
    celu,
    elu,
    gelu,
@@ -19,8 +21,10 @@ from mlx.nn.layers.activations import (
    log_sigmoid,
    relu,
    relu6,
    selu,
    silu,
    softplus,
    step,
)
from mlx.nn.layers.base import Module
from mlx.nn.layers.containers import Sequential

View File

@@ -74,7 +74,7 @@ def celu(x, alpha=1.0):
def silu(x):
    r"""Applies the Sigmoid Linear Unit. Also known as Swish.

    Applies :math:`x \sigma(x)` element wise, where :math:`\sigma(\cdot)` is
    the logistic sigmoid.
@@ -143,6 +143,41 @@ class Sigmoid(Module):
    pass


def step(x: mx.array, threshold: float = 0.0):
    r"""Applies the Step Activation Function.

    This function implements a binary step activation, where the output is set
    to 1 if the input is greater than a specified threshold, and 0 otherwise.

    .. math::
        \text{step}(x) = \begin{cases}
        0 & \text{if } x \leq \text{threshold} \\
        1 & \text{if } x > \text{threshold}
        \end{cases}

    Args:
        threshold: The value to threshold at.
    """

    return mx.where(x > threshold, 1, 0)


def selu(x):
    r"""Applies the Scaled Exponential Linear Unit.

    .. math::
        \text{selu}(x) = \begin{cases}
        \lambda x & \text{if } x > 0 \\
        \lambda \alpha (\exp(x) - 1) & \text{if } x \leq 0
        \end{cases}

    where :math:`\lambda = 1.0507` and :math:`\alpha = 1.67326`.

    See also :func:`elu`.
    """
    return elu(x, 1.67326) * 1.0507


@_make_activation_module(relu)
class ReLU(Module):
    pass
@@ -274,3 +309,32 @@ def tanh(x):
@_make_activation_module(tanh)
class Tanh(Module):
    pass


class Step(Module):
    r"""Applies the Step Activation Function.

    This function implements a binary step activation, where the output is set
    to 1 if the input is greater than a specified threshold, and 0 otherwise.

    .. math::
        \text{step}(x) = \begin{cases}
        0 & \text{if } x \leq \text{threshold} \\
        1 & \text{if } x > \text{threshold}
        \end{cases}

    Args:
        threshold: The value to threshold at.
    """

    def __init__(self, threshold: float = 0.0):
        super().__init__()
        self.threshold = threshold

    def __call__(self, x: mx.array):
        return step(x, self.threshold)


@_make_activation_module(selu)
class SELU(Module):
    pass
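
As a sanity check on the formulas above, the following sketch (not part of the commit; the tolerance is an arbitrary choice) evaluates the new functions against their closed forms:

import math

import mlx.core as mx
import mlx.nn as nn

lam, alpha = 1.0507, 1.67326
values = [-2.0, -1.0, 0.0, 1.0, 2.0]
x = mx.array(values)

# selu(x) == lam * x for x > 0, and lam * alpha * (exp(x) - 1) otherwise.
reference = mx.array(
    [lam * v if v > 0 else lam * alpha * (math.exp(v) - 1) for v in values]
)
assert mx.allclose(nn.selu(x), reference, atol=1e-5)

# step(x) is 1 exactly where x is strictly above the threshold.
assert mx.array_equal(nn.step(x, threshold=0.0), mx.array([0, 0, 0, 1, 1]))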

View File

@@ -449,6 +449,32 @@ class TestNN(mlx_tests.MLXTestCase):
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

    def test_step_activation(self):
        x = mx.arange(-3, 4)
        expected = mx.array([0, 0, 0, 0, 1, 1, 1])
        y = nn.Step()(x)
        self.assertTrue(mx.array_equal(y, expected))

        y = nn.Step(2)(x)
        expected = mx.array([0, 0, 0, 0, 0, 0, 1])
        self.assertTrue(mx.array_equal(y, expected))

    def test_selu(self):
        x = mx.arange(-3, 4)
        expected = mx.array(
            [
                -1.670563817024231,
                -1.5201621055603027,
                -1.1113275289535522,
                0.0,
                1.0506999492645264,
                2.1013998985290527,
                3.152099847793579,
            ]
        )
        y = nn.SELU()(x)
        self.assertTrue(mx.allclose(y, expected))


if __name__ == "__main__":
    unittest.main()