Add back compiled function signatures and docstrings (#749)

* try to add back compiled function signatures and docstrings

* add indentation to docstring
Author: Awni Hannun
Date: 2024-02-27 13:18:59 -08:00 (committed by GitHub)
Parent: 56ba3ec40e
Commit: 420ff2f331
2 changed files with 37 additions and 5 deletions
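
The change restores the Python-visible metadata (return annotations and docstrings) on the compiled activation functions. As a rough illustration of the idea only, not the code in this commit, the hypothetical helper below copies a function's signature and docstring onto its mx.compile wrapper; the diff itself only touches the annotations and docstring text.

# Hypothetical illustration: `compile_with_metadata` is not an MLX API. It
# shows one way a compiled function can keep its signature and docstring,
# which is the behavior this commit restores for the nn activations.
import inspect
from functools import partial, wraps

import mlx.core as mx


def compile_with_metadata(fn, **compile_kwargs):
    """Compile ``fn`` and copy its metadata onto the returned callable."""
    compiled = mx.compile(fn, **compile_kwargs)

    @wraps(fn)  # copies __doc__, __name__, __annotations__, sets __wrapped__
    def wrapper(*args, **kwargs):
        return compiled(*args, **kwargs)

    return wrapper


@partial(compile_with_metadata, shapeless=True)
def silu(x: mx.array) -> mx.array:
    """Applies x * sigmoid(x), element-wise."""
    return x * mx.sigmoid(x)


print(silu.__doc__)             # the docstring survives compilation
print(inspect.signature(silu))  # original signature, not (*args, **kwargs)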


@@ -142,11 +142,11 @@ def log_sigmoid(x):
 @partial(mx.compile, shapeless=True)
-def gelu(x):
+def gelu(x) -> mx.array:
     r"""Applies the Gaussian Error Linear Units function.
 
     .. math::
-        \\textrm{GELU}(x) = x * \Phi(x)
+        \textrm{GELU}(x) = x * \Phi(x)
 
     where :math:`\Phi(x)` is the Gaussian CDF.
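
The docstring above states GELU(x) = x * Φ(x) with Φ the standard normal CDF. As a minimal sketch of that formula only (illustrative, not the library implementation), the CDF can be written with the error function:

# Sketch of the docstring formula: GELU(x) = x * Phi(x), with Phi the
# standard normal CDF expressed through the error function.
import math

import mlx.core as mx


def gelu_sketch(x: mx.array) -> mx.array:
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2


print(gelu_sketch(mx.array([-1.0, 0.0, 1.0])))  # approximately [-0.1587, 0, 0.8413]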
@@ -203,7 +203,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
     (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
 
     .. math::
-        textrm{GLU}(x) = a * \sigma(b)
+        \textrm{GLU}(x) = a * \sigma(b)
 
     Args:
         axis (int): The dimension to split along. Default: ``-1``
@@ -264,6 +264,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array:
 @partial(mx.compile, shapeless=True)
 def mish(x: mx.array) -> mx.array:
     r"""Applies the Mish function, element-wise.
+
     Mish: A Self Regularized Non-Monotonic Neural Activation Function.
 
     Reference: https://arxiv.org/abs/1908.08681
@@ -301,7 +302,7 @@ class GLU(Module):
     (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
 
     .. math::
-        textrm{GLU}(x) = a * \sigma(b)
+        \textrm{GLU}(x) = a * \sigma(b)
 
     Args:
         axis (int): The dimension to split along. Default: ``-1``
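
Both GLU hunks fix the same docstring formula: split the input in two along ``axis`` into a and b, then apply a * σ(b). A small sketch of what the formula describes (``glu_sketch`` is illustrative, not the library implementation):

# Illustrative sketch of the GLU formula in the docstrings above: split the
# input in half along `axis` and gate one half with the sigmoid of the other.
import mlx.core as mx


def glu_sketch(x: mx.array, axis: int = -1) -> mx.array:
    a, b = mx.split(x, 2, axis=axis)
    return a * mx.sigmoid(b)


x = mx.random.normal((4, 6))
print(glu_sketch(x).shape)  # last axis halved: 6 -> 3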