Mirror of https://github.com/ml-explore/mlx.git
Added GLU activation function and Gated activation function (#329)
* Added GLU activation function and gated activation function
* Ran pre-commit
* Ran pre-commit
* Removed old sigmoid implementation to match with main
* Removed gated activation from __init__.py
* Removed unused test cases
* Removed unused imports
* Format / docstring

Co-authored-by: Awni Hannun <awni@apple.com>
parent 026ef9aae4
commit 022a944367
python/mlx/nn/layers/__init__.py

@@ -4,6 +4,7 @@ from mlx.nn.layers.activations import (
     CELU,
     ELU,
     GELU,
+    GLU,
     SELU,
     Hardswish,
     LeakyReLU,
@@ -25,6 +26,7 @@ from mlx.nn.layers.activations import (
     gelu,
     gelu_approx,
     gelu_fast_approx,
+    glu,
     hardswish,
     leaky_relu,
     log_sigmoid,
python/mlx/nn/layers/activations.py

@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import math
+from typing import Any
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
@@ -162,6 +163,43 @@ def gelu_fast_approx(x):
     return x * mx.sigmoid(1.773 * x)
 
 
+def glu(x: mx.array, axis: int = -1) -> mx.array:
+    r"""Applies the gated linear unit function.
+
+    This function splits the ``axis`` dimension of the input into two halves
+    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
+
+    .. math::
+        \textrm{GLU}(x) = a * \sigma(b)
+
+    Args:
+        axis (int): The dimension to split along. Default: ``-1``.
+    """
+    a, b = mx.split(x, indices_or_sections=2, axis=axis)
+    return a * mx.sigmoid(b)
+
+
+class GLU(Module):
+    r"""Applies the gated linear unit function.
+
+    This function splits the ``axis`` dimension of the input into two halves
+    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
+
+    .. math::
+        \textrm{GLU}(x) = a * \sigma(b)
+
+    Args:
+        axis (int): The dimension to split along. Default: ``-1``.
+    """
+
+    def __init__(self, axis: int = -1):
+        super().__init__()
+        self.axis = axis
+
+    def __call__(self, x) -> Any:
+        return glu(x=x, axis=self.axis)
+
+
 def step(x: mx.array, threshold: float = 0.0):
     r"""Applies the Step Activation Function.
 
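For readers skimming this commit, a minimal usage sketch of the new functional and module forms follows; the input shape and values here are illustrative assumptions and are not part of the diff.

# Usage sketch (assumed example, not part of the commit).
# glu splits the chosen axis in half and gates the first half with the
# sigmoid of the second, so that axis is halved in the output.
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal(shape=[2, 8])   # assumed input shape

y_fn = nn.glu(x)                     # functional form, axis defaults to -1
y_mod = nn.GLU(axis=-1)(x)           # equivalent module form

print(y_fn.shape)                    # last axis halved from 8 to 4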
python/tests/test_nn.py

@@ -789,6 +789,12 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [5])
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_glu(self):
+        x = mx.array([[[1.0, 2.0, 3.0, 4.0]]], dtype=mx.float32)
+        y = mx.array([[[0.952574, 1.96403]]], dtype=mx.float32)
+        out = nn.glu(x)
+        self.assertEqualArray(out, y)
+
     def test_rope(self):
         for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
             rope = nn.RoPE(4, **kwargs)
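As a quick sanity check on the expected values in test_glu (a hand computation added here for illustration, not part of the commit): the last axis of [[1, 2, 3, 4]] splits into a = [1, 2] and b = [3, 4], so the output is [1 * sigmoid(3), 2 * sigmoid(4)].

# Hand-recomputing the expected values used in test_glu above
# (illustrative check, not part of the diff).
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

a = [1.0, 2.0]   # first half of the split
b = [3.0, 4.0]   # second half of the split
out = [ai * sigmoid(bi) for ai, bi in zip(a, b)]
print(out)       # ~[0.952574, 1.964028], matching the expected array in the test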