Add Step, ELU, SELU, Swish activation functions (#117)

* Add Step, ELU, SELU, Swish activation functions

This commit adds the Step, ELU, SELU and Swish activation functions

* add to the docs

* review
Nicholas Santavas 2023-12-12 02:04:07 +01:00 committed by GitHub
parent b9226c367c
commit f5df47ec6e
7 changed files with 132 additions and 2 deletions

@@ -223,6 +223,20 @@ def topk(axis, x):
    mx.eval(ys)


def step_function(x):
    y = x
    for i in range(100):
        y = nn.step(x)
    mx.eval(y)


def selu(x):
    y = x
    for i in range(100):
        y = nn.selu(x)
    mx.eval(y)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", help="Choose the benchmark to run")
@@ -372,5 +386,11 @@ if __name__ == "__main__":
    elif args.benchmark == "topk":
        print(bench(topk, axis, x))
    elif args.benchmark == "step":
        print(bench(step_function, x))
    elif args.benchmark == "selu":
        print(bench(selu, x))
    else:
        raise ValueError("Unknown benchmark")
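
Note: a minimal sketch of what the new step and selu benchmarks exercise, outside the bench harness; the 32x16x1024 shape is borrowed from the comparison script further down, and the imports are assumed to match the rest of this benchmark file.

import mlx.core as mx
import mlx.nn as nn

x = mx.random.uniform(shape=[32, 16, 1024])
y = nn.step(x)   # binary step, threshold defaults to 0.0
z = nn.selu(x)   # scaled exponential linear unit
mx.eval(y, z)    # force evaluation, as the timed loops above do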

@@ -257,6 +257,14 @@ def topk(axis, x):
    sync_if_needed(x)


@torch.no_grad()
def selu(x):
    y = x
    for i in range(100):
        y = torch.nn.functional.selu(y)
    sync_if_needed(x)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", help="Choose the benchmark to run")
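
Note: the PyTorch benchmark times the built-in counterpart; a minimal sketch of that call, with the input shape again assumed from the comparison script below.

import torch

x = torch.rand(32, 16, 1024)
with torch.no_grad():
    y = torch.nn.functional.selu(x)  # reference op the selu comparison runs against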

@@ -205,6 +205,10 @@ if __name__ == "__main__":
    compare_filtered("celu --size 32x16x1024 --cpu")
    compare_filtered("log_sigmoid --size 32x16x1024")
    compare_filtered("log_sigmoid --size 32x16x1024 --cpu")
    compare_filtered("step --size 32x16x1024")
    compare_filtered("step --size 32x16x1024 --cpu")
    compare_filtered("selu --size 32x16x1024")
    compare_filtered("selu --size 32x16x1024 --cpu")
    compare_filtered("scalar_mul --size 32x16x1024")
    compare_filtered("scalar_mul --size 32x16x1024 --cpu")
    compare_filtered("cross_entropy --size 256x1024")

@@ -97,7 +97,7 @@ Updating the parameters
MLX modules allow accessing and updating individual parameters. However, most
times we need to update large subsets of a module's parameters. This action is
performed by :meth:`Module.update`.

Value and grad
--------------
@@ -148,6 +148,8 @@ Neural Network Layers
   ReLU
   GELU
   SiLU
   Step
   SELU
   Linear
   Conv1d
   Conv2d
@@ -170,6 +172,8 @@ simple functions.
   gelu_fast_approx
   relu
   silu
   step
   selu

Loss Functions
--------------

@@ -4,12 +4,14 @@ from mlx.nn.layers.activations import (
    CELU,
    ELU,
    GELU,
    SELU,
    LeakyReLU,
    LogSigmoid,
    ReLU,
    ReLU6,
    SiLU,
    Softplus,
    Step,
    celu,
    elu,
    gelu,
@@ -19,8 +21,10 @@ from mlx.nn.layers.activations import (
    log_sigmoid,
    relu,
    relu6,
    selu,
    silu,
    softplus,
    step,
)
from mlx.nn.layers.base import Module
from mlx.nn.layers.containers import Sequential
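
Note: with these re-exports, both spellings resolve through the package, which is how the tests below reach them; a quick sketch with illustrative inputs (not taken from the diff).

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-1.0, 0.0, 1.0])
assert mx.array_equal(nn.Step()(x), nn.step(x))  # the module form wraps the function
assert mx.allclose(nn.SELU()(x), nn.selu(x))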

@@ -74,7 +74,7 @@ def celu(x, alpha=1.0):
def silu(x):
    r"""Applies the Sigmoid Linear Unit. Also known as Swish.

    Applies :math:`x \sigma(x)` element wise, where :math:`\sigma(\cdot)` is
    the logistic sigmoid.
@@ -143,6 +143,41 @@ class Sigmoid(Module):
    pass
def step(x: mx.array, threshold: float = 0.0):
    r"""Applies the Step Activation Function.

    This function implements a binary step activation, where the output is set
    to 1 if the input is greater than a specified threshold, and 0 otherwise.

    .. math::
        \text{step}(x) = \begin{cases}
        0 & \text{if } x \leq \text{threshold} \\
        1 & \text{if } x > \text{threshold}
        \end{cases}

    Args:
        threshold: The value to threshold at.
    """
    return mx.where(x > threshold, 1, 0)


def selu(x):
    r"""Applies the Scaled Exponential Linear Unit.

    .. math::
        \text{selu}(x) = \begin{cases}
        \lambda x & \text{if } x > 0 \\
        \lambda \alpha (\exp(x) - 1) & \text{if } x \leq 0
        \end{cases}

    where :math:`\lambda = 1.0507` and :math:`\alpha = 1.67326`.

    See also :func:`elu`.
    """
    return elu(x, 1.67326) * 1.0507


@_make_activation_module(relu)
class ReLU(Module):
    pass
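
Note: a quick numeric check, in plain Python, of why elu(x, 1.67326) * 1.0507 matches the piecewise formula above; the value at x = -1 lines up with the test_selu expectation at the end of this commit.

import math

lam, alpha = 1.0507, 1.67326
x = -1.0
piecewise = lam * alpha * (math.exp(x) - 1)   # the x <= 0 branch of the formula
composed = (alpha * (math.exp(x) - 1)) * lam  # what elu(x, 1.67326) * 1.0507 computes
assert abs(piecewise - composed) < 1e-12      # both are approximately -1.1113
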
@@ -274,3 +309,32 @@ def tanh(x):
@_make_activation_module(tanh)
class Tanh(Module):
    pass


class Step(Module):
    r"""Applies the Step Activation Function.

    This function implements a binary step activation, where the output is set
    to 1 if the input is greater than a specified threshold, and 0 otherwise.

    .. math::
        \text{step}(x) = \begin{cases}
        0 & \text{if } x \leq \text{threshold} \\
        1 & \text{if } x > \text{threshold}
        \end{cases}

    Args:
        threshold: The value to threshold at.
    """

    def __init__(self, threshold: float = 0.0):
        super().__init__()
        self.threshold = threshold

    def __call__(self, x: mx.array):
        return step(x, self.threshold)


@_make_activation_module(selu)
class SELU(Module):
    pass

@@ -449,6 +449,32 @@ class TestNN(mlx_tests.MLXTestCase):
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

    def test_step_activation(self):
        x = mx.arange(-3, 4)
        expected = mx.array([0, 0, 0, 0, 1, 1, 1])
        y = nn.Step()(x)
        self.assertTrue(mx.array_equal(y, expected))

        y = nn.Step(2)(x)
        expected = mx.array([0, 0, 0, 0, 0, 0, 1])
        self.assertTrue(mx.array_equal(y, expected))
    def test_selu(self):
        x = mx.arange(-3, 4)
        expected = mx.array(
            [
                -1.670563817024231,
                -1.5201621055603027,
                -1.1113275289535522,
                0.0,
                1.0506999492645264,
                2.1013998985290527,
                3.152099847793579,
            ]
        )
        y = nn.SELU()(x)
        self.assertTrue(mx.allclose(y, expected))


if __name__ == "__main__":
    unittest.main()
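
Note: the expected values in test_selu can be reproduced directly from the formula; a sketch in plain Python (floating point, so only approximate agreement is expected).

import math

lam, alpha = 1.0507, 1.67326
ref = [lam * v if v > 0 else lam * alpha * (math.exp(v) - 1) for v in range(-3, 4)]
# ref is approximately [-1.6706, -1.5202, -1.1113, 0.0, 1.0507, 2.1014, 3.1521],
# which is what mx.allclose compares against nn.SELU()(mx.arange(-3, 4)) above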