Add back compiled function signatures and docstrings (#749)

* try to add back compiled function signatures and docstrings

* add indentation to docstring
This commit is contained in:
Awni Hannun 2024-02-27 13:18:59 -08:00 committed by GitHub
parent 56ba3ec40e
commit 420ff2f331
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 5 deletions

View File

@ -142,11 +142,11 @@ def log_sigmoid(x):
@partial(mx.compile, shapeless=True)
def gelu(x):
def gelu(x) -> mx.array:
r"""Applies the Gaussian Error Linear Units function.
.. math::
\\textrm{GELU}(x) = x * \Phi(x)
\textrm{GELU}(x) = x * \Phi(x)
where :math:`\Phi(x)` is the Gaussian CDF.
@ -203,7 +203,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
.. math::
textrm{GLU}(x) = a * \sigma(b)
\textrm{GLU}(x) = a * \sigma(b)
Args:
axis (int): The dimension to split along. Default: ``-1``
@ -264,6 +264,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array:
@partial(mx.compile, shapeless=True)
def mish(x: mx.array) -> mx.array:
r"""Applies the Mish function, element-wise.
Mish: A Self Regularized Non-Monotonic Neural Activation Function.
Reference: https://arxiv.org/abs/1908.08681
@ -301,7 +302,7 @@ class GLU(Module):
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
.. math::
textrm{GLU}(x) = a * \sigma(b)
\textrm{GLU}(x) = a * \sigma(b)
Args:
axis (int): The dimension to split along. Default: ``-1``

View File

@ -800,7 +800,38 @@ void init_transforms(py::module_& m) {
const py::object& inputs,
const py::object& outputs,
bool shapeless) {
return py::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless});
py::options options;
options.disable_function_signatures();
std::ostringstream doc;
auto name = fun.attr("__name__").cast<std::string>();
doc << name;
// Try to get the signature
auto inspect = py::module::import("inspect");
if (!inspect.attr("isbuiltin")(fun).cast<bool>()) {
doc << inspect.attr("signature")(fun)
.attr("__str__")()
.cast<std::string>();
}
// Try to get the doc string
if (auto d = fun.attr("__doc__"); py::isinstance<py::str>(d)) {
doc << "\n\n";
auto dstr = d.cast<std::string>();
// Add spaces to match first line indentation with remainder of
// docstring
int i = 0;
for (int i = dstr.size() - 1; i >= 0 && dstr[i] == ' '; i--) {
doc << ' ';
}
doc << dstr;
}
auto doc_str = doc.str();
return py::cpp_function(
PyCompiledFun{fun, inputs, outputs, shapeless},
py::name(name.c_str()),
py::doc(doc_str.c_str()));
},
"fun"_a,
"inputs"_a = std::nullopt,