Add Step, ELU, SELU, Swish activation functions (#117)

* Add Step, ELU, SELU, Swish activation functions

  This commit adds the Step, ELU, SELU, and Swish activation functions.

* add to the docs

* review
parent b9226c367c
commit f5df47ec6e
@@ -223,6 +223,20 @@ def topk(axis, x):
     mx.eval(ys)
 
 
+def step_function(x):
+    y = x
+    for i in range(100):
+        y = nn.step(x)
+    mx.eval(y)
+
+
+def selu(x):
+    y = x
+    for i in range(100):
+        y = nn.selu(x)
+    mx.eval(y)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("benchmark", help="Choose the benchmark to run")
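The `bench` helper these functions are handed to is defined elsewhere in the benchmark script and is not part of this diff. As a minimal sketch, assuming it simply returns the best wall-clock time over a few repetitions of `f(*args)`, it might look like the following; note that MLX evaluates lazily, which is why each benchmark body ends in `mx.eval(y)` to force the computation.

```python
import time


def bench(f, *args):
    # Hypothetical stand-in for the harness used by the benchmark script; the
    # real helper is not shown in this diff. Returns the best time in
    # milliseconds over a few repetitions of f(*args).
    best = float("inf")
    for _ in range(5):
        tic = time.perf_counter()
        f(*args)
        toc = time.perf_counter()
        best = min(best, toc - tic)
    return 1000 * best
```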
@ -372,5 +386,11 @@ if __name__ == "__main__":
|
|||||||
elif args.benchmark == "topk":
|
elif args.benchmark == "topk":
|
||||||
print(bench(topk, axis, x))
|
print(bench(topk, axis, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "step":
|
||||||
|
print(bench(step_function, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "selu":
|
||||||
|
print(bench(selu, x))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown benchmark")
|
raise ValueError("Unknown benchmark")
|
||||||
|
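With these branches in place, each new benchmark is selected via the `benchmark` positional argument; judging by the `compare_filtered` entries added further down, an invocation presumably looks like `python <mlx benchmark script> step --size 32x16x1024`.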
@@ -257,6 +257,14 @@ def topk(axis, x):
     sync_if_needed(x)
 
 
+@torch.no_grad()
+def selu(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.selu(y)
+    sync_if_needed(x)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("benchmark", help="Choose the benchmark to run")
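`sync_if_needed` is likewise defined outside this diff. A plausible sketch, assuming it synchronizes the device before the timer stops so that asynchronous GPU work is actually counted:

```python
import torch


def sync_if_needed(x):
    # Hypothetical sketch; the real helper lives elsewhere in the PyTorch
    # benchmark script. GPU kernels launch asynchronously, so the device must
    # be synchronized for timings to be meaningful; on CPU this is a no-op.
    if x.device.type == "mps":
        torch.mps.synchronize()
    elif x.device.type == "cuda":
        torch.cuda.synchronize()
```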
@ -205,6 +205,10 @@ if __name__ == "__main__":
|
|||||||
compare_filtered("celu --size 32x16x1024 --cpu")
|
compare_filtered("celu --size 32x16x1024 --cpu")
|
||||||
compare_filtered("log_sigmoid --size 32x16x1024")
|
compare_filtered("log_sigmoid --size 32x16x1024")
|
||||||
compare_filtered("log_sigmoid --size 32x16x1024 --cpu")
|
compare_filtered("log_sigmoid --size 32x16x1024 --cpu")
|
||||||
|
compare_filtered("step --size 32x16x1024")
|
||||||
|
compare_filtered("step --size 32x16x1024 --cpu")
|
||||||
|
compare_filtered("selu --size 32x16x1024")
|
||||||
|
compare_filtered("selu --size 32x16x1024 --cpu")
|
||||||
compare_filtered("scalar_mul --size 32x16x1024")
|
compare_filtered("scalar_mul --size 32x16x1024")
|
||||||
compare_filtered("scalar_mul --size 32x16x1024 --cpu")
|
compare_filtered("scalar_mul --size 32x16x1024 --cpu")
|
||||||
compare_filtered("cross_entropy --size 256x1024")
|
compare_filtered("cross_entropy --size 256x1024")
|
||||||
|
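Each `compare_filtered` line appears to run the named benchmark with the given tensor size in both the MLX and PyTorch scripts (on the GPU by default, on the CPU when `--cpu` is passed), so the new `step` and `selu` benchmarks get the same MLX-vs-PyTorch comparison as the existing activations.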
@@ -97,7 +97,7 @@ Updating the parameters
 
 MLX modules allow accessing and updating individual parameters. However, most
 times we need to update large subsets of a module's parameters. This action is
 performed by :meth:`Module.update`.
 
 Value and grad
 --------------
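To make the referenced behavior concrete, here is a small sketch (not part of this diff) of updating a whole parameter subtree at once with `Module.update`; `nn.Linear` and `mlx.utils.tree_map` are standard MLX pieces, but the snippet itself is illustrative:

```python
import mlx.nn as nn
from mlx.utils import tree_map

model = nn.Linear(4, 4)
# Build a new parameter tree (here: every parameter scaled by 0.9) and swap
# it into the module in one call instead of assigning parameters one by one.
model.update(tree_map(lambda p: 0.9 * p, model.parameters()))
```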
@@ -148,6 +148,8 @@ Neural Network Layers
    ReLU
    GELU
    SiLU
+   Step
+   SELU
    Linear
    Conv1d
    Conv2d
@@ -170,6 +172,8 @@ simple functions.
    gelu_fast_approx
    relu
    silu
+   step
+   selu
 
 Loss Functions
 --------------
@@ -4,12 +4,14 @@ from mlx.nn.layers.activations import (
     CELU,
     ELU,
     GELU,
+    SELU,
     LeakyReLU,
     LogSigmoid,
     ReLU,
     ReLU6,
     SiLU,
     Softplus,
+    Step,
     celu,
     elu,
     gelu,
@@ -19,8 +21,10 @@ from mlx.nn.layers.activations import (
     log_sigmoid,
     relu,
     relu6,
+    selu,
     silu,
     softplus,
+    step,
 )
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.containers import Sequential
@@ -74,7 +74,7 @@ def celu(x, alpha=1.0):
 
 
 def silu(x):
-    r"""Applies the Sigmoid Linear Unit.
+    r"""Applies the Sigmoid Linear Unit. Also known as Swish.
 
     Applies :math:`x \sigma(x)` element wise, where :math:`\sigma(\cdot)` is
     the logistic sigmoid.
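Since the docstring now names both aliases, it may help to see that `silu` (Swish with β = 1) is just an elementwise product with the logistic sigmoid; a quick sketch using core ops:

```python
import mlx.core as mx

x = mx.array([-2.0, 0.0, 2.0])
# silu(x) = x * sigmoid(x); "Swish" refers to the same function (beta = 1).
y = x * mx.sigmoid(x)
```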
@@ -143,6 +143,41 @@ class Sigmoid(Module):
     pass
 
 
+def step(x: mx.array, threshold: float = 0.0):
+    r"""Applies the Step Activation Function.
+
+    This function implements a binary step activation, where the output is set
+    to 1 if the input is greater than a specified threshold, and 0 otherwise.
+
+    .. math::
+        \text{step}(x) = \begin{cases}
+        0 & \text{if } x \leq \text{threshold} \\
+        1 & \text{if } x > \text{threshold}
+        \end{cases}
+
+    Args:
+        threshold: The value to threshold at.
+    """
+
+    return mx.where(x > threshold, 1, 0)
+
+
+def selu(x):
+    r"""Applies the Scaled Exponential Linear Unit.
+
+    .. math::
+        \text{selu}(x) = \begin{cases}
+        \lambda x & \text{if } x > 0 \\
+        \lambda \alpha (\exp(x) - 1) & \text{if } x \leq 0
+        \end{cases}
+
+    where :math:`\lambda = 1.0507` and :math:`\alpha = 1.67326`.
+
+    See also :func:`elu`.
+    """
+    return elu(x, 1.67326) * 1.0507
+
+
 @_make_activation_module(relu)
 class ReLU(Module):
     pass
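A quick pure-Python spot check of the constants, assuming `elu(x, alpha)` is `x` for positive inputs and `alpha * (exp(x) - 1)` otherwise; the outputs line up with the expected values in the tests further down:

```python
import math


def selu_ref(x):
    # Reference recomputation of selu(x) = 1.0507 * elu(x, 1.67326).
    return 1.0507 * (x if x > 0 else 1.67326 * (math.exp(x) - 1))


print(selu_ref(1.0))   # ~1.0507
print(selu_ref(-1.0))  # ~-1.1113
print(selu_ref(-3.0))  # ~-1.6706
```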
@@ -274,3 +309,32 @@ def tanh(x):
 @_make_activation_module(tanh)
 class Tanh(Module):
     pass
+
+
+class Step(Module):
+    r"""Applies the Step Activation Function.
+
+    This function implements a binary step activation, where the output is set
+    to 1 if the input is greater than a specified threshold, and 0 otherwise.
+
+    .. math::
+        \text{step}(x) = \begin{cases}
+        0 & \text{if } x \leq \text{threshold} \\
+        1 & \text{if } x > \text{threshold}
+        \end{cases}
+
+    Args:
+        threshold: The value to threshold at.
+    """
+
+    def __init__(self, threshold: float = 0.0):
+        super().__init__()
+        self.threshold = threshold
+
+    def __call__(self, x: mx.array):
+        return step(x, self.threshold)
+
+
+@_make_activation_module(selu)
+class SELU(Module):
+    pass
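With the exports added above, both the functional and module spellings are available from `mlx.nn`; a small usage sketch (the printed values follow from the strict `>` comparison in `step`):

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.array([-0.5, 0.0, 0.5, 2.5])

print(nn.step(x))             # functional form -> [0, 0, 1, 1]
print(nn.selu(x))             # scaled ELU of each element

act = nn.Step(threshold=2.0)  # module form, usable inside nn.Sequential
print(act(x))                 # -> [0, 0, 0, 1]
```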
@@ -449,6 +449,32 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [3])
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_step_activation(self):
+        x = mx.arange(-3, 4)
+        expected = mx.array([0, 0, 0, 0, 0, 1, 1])
+        y = nn.Step()(x)
+        self.assertTrue(mx.array_equal(y, expected))
+
+        y = nn.Step(2)(x)
+        expected = mx.array([0, 0, 0, 0, 0, 0, 1])
+        self.assertTrue(mx.array_equal(y, expected))
+
+    def test_selu(self):
+        x = mx.arange(-3, 4)
+        expected = mx.array(
+            [
+                -1.670563817024231,
+                -1.5201621055603027,
+                -1.1113275289535522,
+                0.0,
+                1.0506999492645264,
+                2.1013998985290527,
+                3.152099847793579,
+            ]
+        )
+        y = nn.SELU()(x)
+        self.assertTrue(mx.allclose(y, expected))
+
+
 if __name__ == "__main__":
     unittest.main()