Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00)
Added GLU activation function and Gated activation function (#329)
* Added GLU activation function and gated activation function
* Ran pre-commit
* Ran pre-commit
* Removed old sigmoid implementation to match with main
* Removed gated activation from __init__.py
* Removed unused test cases
* Removed unused imports
* format / docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
  parent 026ef9aae4
  commit 022a944367
@@ -4,6 +4,7 @@ from mlx.nn.layers.activations import (
     CELU,
     ELU,
     GELU,
+    GLU,
     SELU,
     Hardswish,
     LeakyReLU,
@@ -25,6 +26,7 @@ from mlx.nn.layers.activations import (
     gelu,
     gelu_approx,
     gelu_fast_approx,
+    glu,
     hardswish,
     leaky_relu,
     log_sigmoid,
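With both GLU and glu added to the re-export list in __init__.py, the module and functional spellings become reachable from the package namespace. A minimal sketch of what that enables (not part of the commit; it assumes mlx.nn surfaces the layers package the same way it does for the other activations):

import mlx.nn as nn

fn = nn.glu       # functional form: glu(x, axis=-1)
layer = nn.GLU()  # Module form wrapping the same operation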
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import math
+from typing import Any
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
@@ -162,6 +163,43 @@ def gelu_fast_approx(x):
     return x * mx.sigmoid(1.773 * x)
 
 
+def glu(x: mx.array, axis: int = -1) -> mx.array:
+    r"""Applies the gated linear unit function.
+
+    This function splits the ``axis`` dimension of the input into two halves
+    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
+
+    .. math::
+        \textrm{GLU}(x) = a * \sigma(b)
+
+    Args:
+        axis (int): The dimension to split along. Default: ``-1``.
+    """
+    a, b = mx.split(x, indices_or_sections=2, axis=axis)
+    return a * mx.sigmoid(b)
+
+
+class GLU(Module):
+    r"""Applies the gated linear unit function.
+
+    This function splits the ``axis`` dimension of the input into two halves
+    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
+
+    .. math::
+        \textrm{GLU}(x) = a * \sigma(b)
+
+    Args:
+        axis (int): The dimension to split along. Default: ``-1``.
+    """
+
+    def __init__(self, axis: int = -1):
+        super().__init__()
+        self.axis = axis
+
+    def __call__(self, x) -> Any:
+        return glu(x=x, axis=self.axis)
+
+
 def step(x: mx.array, threshold: float = 0.0):
     r"""Applies the Step Activation Function.
 
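As a quick illustration of the new function and module (a sketch written for this note, not code from the commit): splitting the last axis of [1, 2, 3, 4] gives a = [1, 2] and b = [3, 4], so the output is a * sigmoid(b).

import mlx.core as mx
import mlx.nn as nn

x = mx.array([[[1.0, 2.0, 3.0, 4.0]]], dtype=mx.float32)

# Functional form: a = [1, 2], b = [3, 4] -> a * sigmoid(b) ~= [0.952574, 1.96403]
print(nn.glu(x))

# Module form stores the split axis; the default axis=-1 matches the functional call.
gate = nn.GLU(axis=-1)
print(gate(x))

Because the input is split into two equal halves, the size of the chosen ``axis`` must be even.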
@@ -789,6 +789,12 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [5])
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_glu(self):
+        x = mx.array([[[1.0, 2.0, 3.0, 4.0]]], dtype=mx.float32)
+        y = mx.array([[[0.952574, 1.96403]]], dtype=mx.float32)
+        out = nn.glu(x)
+        self.assertEqualArray(out, y)
+
     def test_rope(self):
         for kwargs in [{}, {"traditional": False}, {"base": 10000}, {"scale": 0.25}]:
             rope = nn.RoPE(4, **kwargs)
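The expected values in test_glu follow directly from the definition; a standalone check using only the standard library (illustrative, not part of the test suite):

import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

# [1, 2, 3, 4] splits along the last axis into a = [1, 2] and b = [3, 4].
a, b = [1.0, 2.0], [3.0, 4.0]
print([ai * sigmoid(bi) for ai, bi in zip(a, b)])
# -> approximately [0.952574, 1.964028], matching the test's [0.952574, 1.96403]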