Mirror of https://github.com/ml-explore/mlx.git
	Added GLU activation function and Gated activation function (#329)
* Added GLU activation function and gated activation function
* Ran pre-commit
* Ran pre-commit
* Removed old sigmoid implementation to match with main
* Removed gated activation from __init__.py
* Removed unused test cases
* Removed unused imports
* format / docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
@@ -4,6 +4,7 @@ from mlx.nn.layers.activations import (
     CELU,
     ELU,
     GELU,
+    GLU,
     SELU,
     Hardswish,
     LeakyReLU,
@@ -25,6 +26,7 @@ from mlx.nn.layers.activations import (
     gelu,
     gelu_approx,
     gelu_fast_approx,
+    glu,
     hardswish,
     leaky_relu,
     log_sigmoid,

@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 import math
+from typing import Any

 import mlx.core as mx
 from mlx.nn.layers.base import Module
@@ -162,6 +163,43 @@ def gelu_fast_approx(x):
     return x * mx.sigmoid(1.773 * x)


+def glu(x: mx.array, axis: int = -1) -> mx.array:
+    r"""Applies the gated linear unit function.
+
+    This function splits the ``axis`` dimension of the input into two halves
+    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
+
+    .. math::
+        \textrm{GLU}(x) = a * \sigma(b)
+
+    Args:
+        axis (int): The dimension to split along. Default: ``-1``.
+    """
+    a, b = mx.split(x, indices_or_sections=2, axis=axis)
+    return a * mx.sigmoid(b)
+
+
+class GLU(Module):
+    r"""Applies the gated linear unit function.
+
+    This function splits the ``axis`` dimension of the input into two halves
+    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
+
+    .. math::
+        \textrm{GLU}(x) = a * \sigma(b)
+
+    Args:
+        axis (int): The dimension to split along. Default: ``-1``.
+    """
+
+    def __init__(self, axis: int = -1):
+        super().__init__()
+        self.axis = axis
+
+    def __call__(self, x) -> Any:
+        return glu(x=x, axis=self.axis)
+
+
 def step(x: mx.array, threshold: float = 0.0):
     r"""Applies the Step Activation Function.

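For context, a minimal usage sketch of the new API (not part of the commit; it assumes the mlx.core and mlx.nn imports shown below). glu halves the size of the split axis: the first half is multiplied elementwise by the sigmoid of the second half, so an input whose last axis has size 8 yields an output whose last axis has size 4.

import mlx.core as mx
import mlx.nn as nn

x = mx.random.uniform(shape=(2, 8))  # the split axis must have an even size

# Functional form: split the last axis into halves a and b, return a * sigmoid(b)
y = nn.glu(x)            # output shape (2, 4)

# Module form: the split axis is fixed at construction time
gate = nn.GLU(axis=-1)
y2 = gate(x)             # equivalent to nn.glu(x, axis=-1)

The module form simply wraps the functional form, which is why the class stores only the axis as state.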
@@ -789,6 +789,12 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [5])
         self.assertEqual(y.dtype, mx.float32)

+    def test_glu(self):
+        x = mx.array([[[1.0, 2.0, 3.0, 4.0]]], dtype=mx.float32)
+        y = mx.array([[[0.952574, 1.96403]]], dtype=mx.float32)
+        out = nn.glu(x)
+        self.assertEqualArray(out, y)
+
     def test_rope(self):
         for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
             rope = nn.RoPE(4, **kwargs)
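As a hand check of the expected values in test_glu (not part of the commit): the last axis of [[[1, 2, 3, 4]]] splits into a = [1, 2] and b = [3, 4], so the output is [1 * sigmoid(3), 2 * sigmoid(4)] ≈ [0.952574, 1.964028], matching the y array asserted above to the printed precision. The same numbers fall out of a plain-Python computation:

import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

a, b = [1.0, 2.0], [3.0, 4.0]
print([ai * sigmoid(bi) for ai, bi in zip(a, b)])  # approx. [0.952574, 1.964028]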
Author: Hazem Essam