mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 18:11:17 +08:00
Add Step, ELU, SELU, Swish activation functions (#117)
* Add Step, ELU, SELU, Swish activation functions

  This commit adds the Step, ELU, SELU and Swish activation functions.

* add to the docs

* review
This commit is contained in:
parent b9226c367c
commit f5df47ec6e
@@ -223,6 +223,20 @@ def topk(axis, x):
    mx.eval(ys)


def step_function(x):
    y = x
    for i in range(100):
        y = nn.step(x)
    mx.eval(y)


def selu(x):
    y = x
    for i in range(100):
        y = nn.selu(x)
    mx.eval(y)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", help="Choose the benchmark to run")

@@ -372,5 +386,11 @@ if __name__ == "__main__":
    elif args.benchmark == "topk":
        print(bench(topk, axis, x))

    elif args.benchmark == "step":
        print(bench(step_function, x))

    elif args.benchmark == "selu":
        print(bench(selu, x))

    else:
        raise ValueError("Unknown benchmark")

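The benchmark loops above end with mx.eval because MLX arrays are evaluated lazily; without it, nothing would actually be computed. A standalone sketch of the same pattern with a wall-clock timer (the timer and array size are illustrative, not the repo's bench helper):

import time
import mlx.core as mx
import mlx.nn as nn

x = mx.random.uniform(shape=(32, 16, 1024))
start = time.perf_counter()
y = x
for _ in range(100):
    y = nn.selu(y)
mx.eval(y)  # force the lazy computation to run before stopping the clock
print(f"{time.perf_counter() - start:.4f} s")
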
@@ -257,6 +257,14 @@ def topk(axis, x):
    sync_if_needed(x)


@torch.no_grad()
def selu(x):
    y = x
    for i in range(100):
        y = torch.nn.functional.selu(y)
    sync_if_needed(x)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", help="Choose the benchmark to run")

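As a quick cross-check on the PyTorch side (illustrative, not part of this diff), torch.nn.functional.selu on the same inputs produces values matching the MLX test expectations further below:

import torch

x = torch.arange(-3, 4, dtype=torch.float32)
print(torch.nn.functional.selu(x))
# approximately tensor([-1.6706, -1.5202, -1.1113, 0.0000, 1.0507, 2.1014, 3.1521])
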
@@ -205,6 +205,10 @@ if __name__ == "__main__":
        compare_filtered("celu --size 32x16x1024 --cpu")
        compare_filtered("log_sigmoid --size 32x16x1024")
        compare_filtered("log_sigmoid --size 32x16x1024 --cpu")
        compare_filtered("step --size 32x16x1024")
        compare_filtered("step --size 32x16x1024 --cpu")
        compare_filtered("selu --size 32x16x1024")
        compare_filtered("selu --size 32x16x1024 --cpu")
        compare_filtered("scalar_mul --size 32x16x1024")
        compare_filtered("scalar_mul --size 32x16x1024 --cpu")
        compare_filtered("cross_entropy --size 256x1024")

@@ -97,7 +97,7 @@ Updating the parameters

MLX modules allow accessing and updating individual parameters. However, most
times we need to update large subsets of a module's parameters. This action is
performed by :meth:`Module.update`.
performed by :meth:`Module.update`.

Value and grad
--------------

@@ -148,6 +148,8 @@ Neural Network Layers
   ReLU
   GELU
   SiLU
   Step
   SELU
   Linear
   Conv1d
   Conv2d

@@ -170,6 +172,8 @@ simple functions.
   gelu_fast_approx
   relu
   silu
   step
   selu

Loss Functions
--------------

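The context paragraph above refers to :meth:`Module.update`, which merges a (possibly nested) dict of arrays into a module's parameters. A minimal sketch, assuming an nn.Linear layer whose parameters are named weight and bias:

import mlx.core as mx
import mlx.nn as nn

layer = nn.Linear(4, 4)
# Keys mirror the module's attribute names; matching parameters are replaced.
layer.update({"weight": mx.zeros([4, 4]), "bias": mx.zeros([4])})
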
@@ -4,12 +4,14 @@ from mlx.nn.layers.activations import (
    CELU,
    ELU,
    GELU,
    SELU,
    LeakyReLU,
    LogSigmoid,
    ReLU,
    ReLU6,
    SiLU,
    Softplus,
    Step,
    celu,
    elu,
    gelu,
@@ -19,8 +21,10 @@ from mlx.nn.layers.activations import (
    log_sigmoid,
    relu,
    relu6,
    selu,
    silu,
    softplus,
    step,
)
from mlx.nn.layers.base import Module
from mlx.nn.layers.containers import Sequential

@@ -74,7 +74,7 @@ def celu(x, alpha=1.0):


def silu(x):
    r"""Applies the Sigmoid Linear Unit.
    r"""Applies the Sigmoid Linear Unit. Also known as Swish.

    Applies :math:`x \sigma(x)` element wise, where :math:`\sigma(\cdot)` is
    the logistic sigmoid.
@@ -143,6 +143,41 @@ class Sigmoid(Module):
    pass


def step(x: mx.array, threshold: float = 0.0):
    r"""Applies the Step Activation Function.

    This function implements a binary step activation, where the output is set
    to 1 if the input is greater than a specified threshold, and 0 otherwise.

    .. math::
        \text{step}(x) = \begin{cases}
        0 & \text{if } x < \text{threshold} \\
        1 & \text{if } x \geq \text{threshold}
        \end{cases}

    Args:
        threshold: The value to threshold at.
    """

    return mx.where(x > threshold, 1, 0)


def selu(x):
    r"""Applies the Scaled Exponential Linear Unit.

    .. math::
        \text{selu}(x) = \begin{cases}
        \lambda x & \text{if } x > 0 \\
        \lambda \alpha (\exp(x) - 1) & \text{if } x \leq 0
        \end{cases}

    where :math:`\lambda = 1.0507` and :math:`\alpha = 1.67326`.

    See also :func:`elu`.
    """
    return elu(x, 1.67326) * 1.0507


@_make_activation_module(relu)
class ReLU(Module):
    pass
@@ -274,3 +309,32 @@ def tanh(x):
@_make_activation_module(tanh)
class Tanh(Module):
    pass


class Step(Module):
    r"""Applies the Step Activation Function.

    This function implements a binary step activation, where the output is set
    to 1 if the input is greater than a specified threshold, and 0 otherwise.

    .. math::
        \text{step}(x) = \begin{cases}
        0 & \text{if } x < \text{threshold} \\
        1 & \text{if } x \geq \text{threshold}
        \end{cases}

    Args:
        threshold: The value to threshold at.
    """

    def __init__(self, threshold: float = 0.0):
        super().__init__()
        self.threshold = threshold

    def __call__(self, x: mx.array):
        return step(x, self.threshold)


@_make_activation_module(selu)
class SELU(Module):
    pass

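Putting the new pieces together, a minimal usage sketch of the activations added here, in both functional and module form (only names introduced or referenced in this commit are used; the Sequential layout is illustrative):

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-2.0, -0.5, 0.0, 1.5])

# Functional forms
print(nn.step(x, threshold=0.5))  # binary 0/1 output
print(nn.selu(x))                 # scaled ELU, lambda ~ 1.0507, alpha ~ 1.67326
print(nn.silu(x))                 # also known as Swish

# Module forms, usable inside nn.Sequential
model = nn.Sequential(nn.Linear(4, 8), nn.SELU(), nn.Linear(8, 1), nn.Step())
print(model(mx.zeros([2, 4])))
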
@@ -449,6 +449,32 @@ class TestNN(mlx_tests.MLXTestCase):
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

    def test_step_activation(self):
        x = mx.arange(-3, 4)
        expected = mx.array([0, 0, 0, 0, 0, 1, 1])
        y = nn.Step()(x)
        self.assertTrue(mx.array_equal(y, expected))

        y = nn.Step(2)(x)
        expected = mx.array([0, 0, 0, 0, 0, 0, 1])
        self.assertTrue(mx.array_equal(y, expected))

    def test_selu(self):
        x = mx.arange(-3, 4)
        expected = mx.array(
            [
                -1.670563817024231,
                -1.5201621055603027,
                -1.1113275289535522,
                0.0,
                1.0506999492645264,
                2.1013998985290527,
                3.152099847793579,
            ]
        )
        y = nn.SELU()(x)
        self.assertTrue(mx.allclose(y, expected))


if __name__ == "__main__":
    unittest.main()

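The expected values in test_selu follow directly from the closed-form definition in the activations file above; a small pure-Python check that reproduces them (no MLX required):

import math

LAMBDA, ALPHA = 1.0507, 1.67326

def selu_ref(v):
    # selu(v) = lambda * v                  if v > 0
    #           lambda * alpha * (e^v - 1)  otherwise
    return LAMBDA * v if v > 0 else LAMBDA * ALPHA * (math.exp(v) - 1.0)

print([round(selu_ref(v), 4) for v in range(-3, 4)])
# [-1.6706, -1.5202, -1.1113, 0.0, 1.0507, 2.1014, 3.1521]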