mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-28 20:41:15 +08:00
Add back compiled function signatures and docstrings (#749)
* try to add back compiled function signatures and docstrings * add indentation to docstring
This commit is contained in:
parent
56ba3ec40e
commit
420ff2f331
@ -142,11 +142,11 @@ def log_sigmoid(x):
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def gelu(x):
|
||||
def gelu(x) -> mx.array:
|
||||
r"""Applies the Gaussian Error Linear Units function.
|
||||
|
||||
.. math::
|
||||
\\textrm{GELU}(x) = x * \Phi(x)
|
||||
\textrm{GELU}(x) = x * \Phi(x)
|
||||
|
||||
where :math:`\Phi(x)` is the Gaussian CDF.
|
||||
|
||||
@ -203,7 +203,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
|
||||
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
|
||||
|
||||
.. math::
|
||||
textrm{GLU}(x) = a * \sigma(b)
|
||||
\textrm{GLU}(x) = a * \sigma(b)
|
||||
|
||||
Args:
|
||||
axis (int): The dimension to split along. Default: ``-1``
|
||||
@ -264,6 +264,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array:
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def mish(x: mx.array) -> mx.array:
|
||||
r"""Applies the Mish function, element-wise.
|
||||
|
||||
Mish: A Self Regularized Non-Monotonic Neural Activation Function.
|
||||
|
||||
Reference: https://arxiv.org/abs/1908.08681
|
||||
@ -301,7 +302,7 @@ class GLU(Module):
|
||||
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
|
||||
|
||||
.. math::
|
||||
textrm{GLU}(x) = a * \sigma(b)
|
||||
\textrm{GLU}(x) = a * \sigma(b)
|
||||
|
||||
Args:
|
||||
axis (int): The dimension to split along. Default: ``-1``
|
||||
|
@ -800,7 +800,38 @@ void init_transforms(py::module_& m) {
|
||||
const py::object& inputs,
|
||||
const py::object& outputs,
|
||||
bool shapeless) {
|
||||
return py::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless});
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
|
||||
std::ostringstream doc;
|
||||
auto name = fun.attr("__name__").cast<std::string>();
|
||||
doc << name;
|
||||
|
||||
// Try to get the signature
|
||||
auto inspect = py::module::import("inspect");
|
||||
if (!inspect.attr("isbuiltin")(fun).cast<bool>()) {
|
||||
doc << inspect.attr("signature")(fun)
|
||||
.attr("__str__")()
|
||||
.cast<std::string>();
|
||||
}
|
||||
|
||||
// Try to get the doc string
|
||||
if (auto d = fun.attr("__doc__"); py::isinstance<py::str>(d)) {
|
||||
doc << "\n\n";
|
||||
auto dstr = d.cast<std::string>();
|
||||
// Add spaces to match first line indentation with remainder of
|
||||
// docstring
|
||||
int i = 0;
|
||||
for (int i = dstr.size() - 1; i >= 0 && dstr[i] == ' '; i--) {
|
||||
doc << ' ';
|
||||
}
|
||||
doc << dstr;
|
||||
}
|
||||
auto doc_str = doc.str();
|
||||
return py::cpp_function(
|
||||
PyCompiledFun{fun, inputs, outputs, shapeless},
|
||||
py::name(name.c_str()),
|
||||
py::doc(doc_str.c_str()));
|
||||
},
|
||||
"fun"_a,
|
||||
"inputs"_a = std::nullopt,
|
||||
|
Loading…
Reference in New Issue
Block a user