mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
Add back compiled function signatures and docstrings (#749)
* try to add back compiled function signatures and docstrings
* add indentation to docstring
This commit is contained in:
@@ -142,11 +142,11 @@ def log_sigmoid(x):
 @partial(mx.compile, shapeless=True)
-def gelu(x):
+def gelu(x) -> mx.array:
     r"""Applies the Gaussian Error Linear Units function.
 
     .. math::
-        \\textrm{GELU}(x) = x * \Phi(x)
+        \textrm{GELU}(x) = x * \Phi(x)
 
     where :math:`\Phi(x)` is the Gaussian CDF.
 
@@ -203,7 +203,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
     (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
 
     .. math::
-        textrm{GLU}(x) = a * \sigma(b)
+        \textrm{GLU}(x) = a * \sigma(b)
 
     Args:
         axis (int): The dimension to split along. Default: ``-1``
@@ -264,6 +264,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array:
+@partial(mx.compile, shapeless=True)
 def mish(x: mx.array) -> mx.array:
     r"""Applies the Mish function, element-wise.
 
     Mish: A Self Regularized Non-Monotonic Neural Activation Function.
 
     Reference: https://arxiv.org/abs/1908.08681
@@ -301,7 +302,7 @@ class GLU(Module):
     (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
 
     .. math::
-        textrm{GLU}(x) = a * \sigma(b)
+        \textrm{GLU}(x) = a * \sigma(b)
 
     Args:
         axis (int): The dimension to split along. Default: ``-1``
|
Reference in New Issue
Block a user