From 75dc537e4426eab4874061ca1fa19308f9fc6ed4 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 4 Jan 2024 13:16:36 -0800
Subject: [PATCH] Fix the sigmoid module (#371)

---
 python/mlx/nn/layers/__init__.py    |  1 +
 python/mlx/nn/layers/activations.py | 18 +++++++++---------
 python/tests/test_nn.py             |  9 +++++++++
 3 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 80d500c9f..6e56efaab 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -13,6 +13,7 @@ from mlx.nn.layers.activations import (
     PReLU,
     ReLU,
     ReLU6,
+    Sigmoid,
     SiLU,
     Softmax,
     Softplus,
diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index 580f0b261..02e37e116 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -162,15 +162,6 @@ def gelu_fast_approx(x):
     return x * mx.sigmoid(1.773 * x)
 
 
-@_make_activation_module
-class Sigmoid(Module):
-    r"""Applies the sigmoid function, element-wise.
-
-    .. math::
-        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
-    """
-
-
 def step(x: mx.array, threshold: float = 0.0):
     r"""Applies the Step Activation Function.
 
@@ -240,6 +231,15 @@ def hardswish(x):
     return x * mx.minimum(max_x_3, 6) / 6
 
 
+@_make_activation_module(mx.sigmoid)
+class Sigmoid(Module):
+    r"""Applies the sigmoid function, element-wise.
+
+    .. math::
+        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
+    """
+
+
 @_make_activation_module(mish)
 class Mish(Module):
     r"""Applies the Mish function, element-wise.
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 28e72a7e7..0941c95ad 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -657,6 +657,15 @@ class TestLayers(mlx_tests.MLXTestCase):
             mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
         )
 
+    def test_sigmoid(self):
+        x = mx.array([1.0, 0.0, -1.0])
+        y1 = mx.sigmoid(x)
+        y2 = nn.activations.sigmoid(x)
+        y3 = nn.Sigmoid()(x)
+
+        self.assertEqualArray(y1, y2, atol=0, rtol=0)
+        self.assertEqualArray(y1, y3, atol=0, rtol=0)
+
     def test_relu(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.relu(x)
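
Note: the diff shows two related fixes. The Sigmoid class previously used the
bare decorator @_make_activation_module, so the decorator received the class
itself rather than an activation function; it is now bound explicitly to
mx.sigmoid. The class is also added to the mlx.nn.layers exports so that
nn.Sigmoid is importable. A minimal usage sketch of the behavior the new test
asserts (this assumes an MLX install with this commit; printed values are
approximate):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([1.0, 0.0, -1.0])

    # With the decorator bound to mx.sigmoid, the module form matches the
    # core op element-wise: sigma(x) = 1 / (1 + exp(-x)).
    print(mx.sigmoid(x))    # ~[0.7311, 0.5, 0.2689]
    print(nn.Sigmoid()(x))  # identical values

Checking against the closed form in the docstring: sigma(1) = 1 / (1 + e^-1)
is about 0.7311, sigma(0) = 0.5, and sigma(-1) is about 0.2689, which is why
the test can compare the three code paths with atol=0 and rtol=0; they all
dispatch to the same mx.sigmoid op.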