diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py index fe9014adc..962022e19 100644 --- a/python/mlx/nn/layers/activations.py +++ b/python/mlx/nn/layers/activations.py @@ -142,11 +142,11 @@ def log_sigmoid(x): @partial(mx.compile, shapeless=True) -def gelu(x): +def gelu(x) -> mx.array: r"""Applies the Gaussian Error Linear Units function. .. math:: - \\textrm{GELU}(x) = x * \Phi(x) + \textrm{GELU}(x) = x * \Phi(x) where :math:`\Phi(x)` is the Gaussian CDF. @@ -203,7 +203,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array: (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`. .. math:: - textrm{GLU}(x) = a * \sigma(b) + \textrm{GLU}(x) = a * \sigma(b) Args: axis (int): The dimension to split along. Default: ``-1`` @@ -264,6 +264,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array: @partial(mx.compile, shapeless=True) def mish(x: mx.array) -> mx.array: r"""Applies the Mish function, element-wise. + Mish: A Self Regularized Non-Monotonic Neural Activation Function. Reference: https://arxiv.org/abs/1908.08681 @@ -301,7 +302,7 @@ class GLU(Module): (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`. .. math:: - textrm{GLU}(x) = a * \sigma(b) + \textrm{GLU}(x) = a * \sigma(b) Args: axis (int): The dimension to split along. Default: ``-1`` diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 12c067be6..997261c36 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -800,7 +800,38 @@ void init_transforms(py::module_& m) { const py::object& inputs, const py::object& outputs, bool shapeless) { - return py::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless}); + py::options options; + options.disable_function_signatures(); + + std::ostringstream doc; + auto name = fun.attr("__name__").cast(); + doc << name; + + // Try to get the signature + auto inspect = py::module::import("inspect"); + if (!inspect.attr("isbuiltin")(fun).cast()) { + doc << inspect.attr("signature")(fun) + .attr("__str__")() + .cast(); + } + + // Try to get the doc string + if (auto d = fun.attr("__doc__"); py::isinstance(d)) { + doc << "\n\n"; + auto dstr = d.cast(); + // Add spaces to match first line indentation with remainder of + // docstring + int i = 0; + for (int i = dstr.size() - 1; i >= 0 && dstr[i] == ' '; i--) { + doc << ' '; + } + doc << dstr; + } + auto doc_str = doc.str(); + return py::cpp_function( + PyCompiledFun{fun, inputs, outputs, shapeless}, + py::name(name.c_str()), + py::doc(doc_str.c_str())); }, "fun"_a, "inputs"_a = std::nullopt,