Added GLU activation function and Gated activation function (#329)

* Added GLU activation function and gated activation function

* Ran pre-commit

* Ran pre commit

* Removed old sigmoid implementation to match with main

* Removed gated activation from __init__.py

* Removed unused test cases

* Removed unused imports

* format / docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Hazem Essam 2024-01-08 16:13:16 +02:00 committed by GitHub
parent 026ef9aae4
commit 022a944367
3 changed files with 46 additions and 0 deletions

@@ -4,6 +4,7 @@ from mlx.nn.layers.activations import (
    CELU,
    ELU,
    GELU,
    GLU,
    SELU,
    Hardswish,
    LeakyReLU,

@@ -25,6 +26,7 @@ from mlx.nn.layers.activations import (
    gelu,
    gelu_approx,
    gelu_fast_approx,
    glu,
    hardswish,
    leaky_relu,
    log_sigmoid,

@@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.

import math
from typing import Any

import mlx.core as mx
from mlx.nn.layers.base import Module

@@ -162,6 +163,43 @@ def gelu_fast_approx(x):
    return x * mx.sigmoid(1.773 * x)


def glu(x: mx.array, axis: int = -1) -> mx.array:
    r"""Applies the gated linear unit function.

    This function splits the ``axis`` dimension of the input into two halves
    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.

    .. math::
        \textrm{GLU}(x) = a * \sigma(b)

    Args:
        axis (int): The dimension to split along. Default: ``-1``.
    """
    a, b = mx.split(x, indices_or_sections=2, axis=axis)
    return a * mx.sigmoid(b)


class GLU(Module):
    r"""Applies the gated linear unit function.

    This function splits the ``axis`` dimension of the input into two halves
    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.

    .. math::
        \textrm{GLU}(x) = a * \sigma(b)

    Args:
        axis (int): The dimension to split along. Default: ``-1``.
    """

    def __init__(self, axis: int = -1):
        super().__init__()
        self.axis = axis

    def __call__(self, x) -> Any:
        return glu(x=x, axis=self.axis)


def step(x: mx.array, threshold: float = 0.0):
    r"""Applies the Step Activation Function.

@@ -789,6 +789,12 @@ class TestLayers(mlx_tests.MLXTestCase):
        self.assertEqual(y.shape, [5])
        self.assertEqual(y.dtype, mx.float32)

    def test_glu(self):
        x = mx.array([[[1.0, 2.0, 3.0, 4.0]]], dtype=mx.float32)
        y = mx.array([[[0.952574, 1.96403]]], dtype=mx.float32)
        out = nn.glu(x)
        self.assertEqualArray(out, y)

    def test_rope(self):
        for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
            rope = nn.RoPE(4, **kwargs)
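
Sanity check on the expected values in test_glu: splitting [1.0, 2.0, 3.0, 4.0] along the last axis gives a = [1, 2] and b = [3, 4], so a * sigmoid(b) = [1 * sigmoid(3), 2 * sigmoid(4)] ≈ [0.952574, 1.96403], which matches the expected array in the test.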