From 022a9443676c47fdfee9416bebf0113c67146ca1 Mon Sep 17 00:00:00 2001
From: Hazem Essam
Date: Mon, 8 Jan 2024 16:13:16 +0200
Subject: [PATCH] Added GLU activation function and Gated activation function (#329)

* Added GLU activation function and gated activation function

* Ran pre-commit

* Ran pre commit

* Removed old sigmoid implementation to match with main

* Removed gated activation from __init__.py

* Removed unused test cases

* Removed unused imports

* format / docstring

---------

Co-authored-by: Awni Hannun
---
 python/mlx/nn/layers/__init__.py    |  2 ++
 python/mlx/nn/layers/activations.py | 38 +++++++++++++++++++++++++++++
 python/tests/test_nn.py             |  6 +++++
 3 files changed, 46 insertions(+)

diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 6e56efaab..00832f525 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -4,6 +4,7 @@ from mlx.nn.layers.activations import (
     CELU,
     ELU,
     GELU,
+    GLU,
     SELU,
     Hardswish,
     LeakyReLU,
@@ -25,6 +26,7 @@ from mlx.nn.layers.activations import (
     gelu,
     gelu_approx,
     gelu_fast_approx,
+    glu,
     hardswish,
     leaky_relu,
     log_sigmoid,
diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index 02e37e116..87d2fbac6 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import math
+from typing import Any
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
@@ -162,6 +163,43 @@ def gelu_fast_approx(x):
     return x * mx.sigmoid(1.773 * x)
 
 
+def glu(x: mx.array, axis: int = -1) -> mx.array:
+    r"""Applies the gated linear unit function.
+
+    This function splits the ``axis`` dimension of the input into two halves
+    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
+
+    .. math::
+        \textrm{GLU}(x) = a * \sigma(b)
+
+    Args:
+        axis (int): The dimension to split along. Default: ``-1``.
+    """
+    a, b = mx.split(x, indices_or_sections=2, axis=axis)
+    return a * mx.sigmoid(b)
+
+
+class GLU(Module):
+    r"""Applies the gated linear unit function.
+
+    This function splits the ``axis`` dimension of the input into two halves
+    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
+
+    .. math::
+        \textrm{GLU}(x) = a * \sigma(b)
+
+    Args:
+        axis (int): The dimension to split along. Default: ``-1``.
+    """
+
+    def __init__(self, axis: int = -1):
+        super().__init__()
+        self.axis = axis
+
+    def __call__(self, x) -> Any:
+        return glu(x=x, axis=self.axis)
+
+
 def step(x: mx.array, threshold: float = 0.0):
     r"""Applies the Step Activation Function.
 
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 0941c95ad..04d849a42 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -789,6 +789,12 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [5])
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_glu(self):
+        x = mx.array([[[1.0, 2.0, 3.0, 4.0]]], dtype=mx.float32)
+        y = mx.array([[[0.952574, 1.96403]]], dtype=mx.float32)
+        out = nn.glu(x)
+        self.assertEqualArray(out, y)
+
     def test_rope(self):
         for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
             rope = nn.RoPE(4, **kwargs)
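
Usage sketch (not part of the patch; assumes the patch applies as-is, and the expected
values mirror the ones used in test_glu):

    import mlx.core as mx
    import mlx.nn as nn

    # Last axis has size 4, so glu splits it into a = [1, 2] and b = [3, 4].
    x = mx.array([[[1.0, 2.0, 3.0, 4.0]]])

    # Functional form: a * sigmoid(b) -> roughly [0.952574, 1.96403]
    y = nn.glu(x)

    # Module form; equivalent for the default axis=-1
    layer = nn.GLU(axis=-1)
    y2 = layer(x)

Note that the size of the split axis must be even, since mx.split divides it into two
equal halves before gating.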