mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-29 21:11:16 +08:00
Add back compiled function signatures and docstrings (#749)
* try to add back compiled function signatures and docstrings * add indentation to docstring
This commit is contained in:
parent
56ba3ec40e
commit
420ff2f331
@ -142,11 +142,11 @@ def log_sigmoid(x):
|
|||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, shapeless=True)
|
@partial(mx.compile, shapeless=True)
|
||||||
def gelu(x):
|
def gelu(x) -> mx.array:
|
||||||
r"""Applies the Gaussian Error Linear Units function.
|
r"""Applies the Gaussian Error Linear Units function.
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
\\textrm{GELU}(x) = x * \Phi(x)
|
\textrm{GELU}(x) = x * \Phi(x)
|
||||||
|
|
||||||
where :math:`\Phi(x)` is the Gaussian CDF.
|
where :math:`\Phi(x)` is the Gaussian CDF.
|
||||||
|
|
||||||
@ -203,7 +203,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
|
|||||||
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
|
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
textrm{GLU}(x) = a * \sigma(b)
|
\textrm{GLU}(x) = a * \sigma(b)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
axis (int): The dimension to split along. Default: ``-1``
|
axis (int): The dimension to split along. Default: ``-1``
|
||||||
@ -264,6 +264,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array:
|
|||||||
@partial(mx.compile, shapeless=True)
|
@partial(mx.compile, shapeless=True)
|
||||||
def mish(x: mx.array) -> mx.array:
|
def mish(x: mx.array) -> mx.array:
|
||||||
r"""Applies the Mish function, element-wise.
|
r"""Applies the Mish function, element-wise.
|
||||||
|
|
||||||
Mish: A Self Regularized Non-Monotonic Neural Activation Function.
|
Mish: A Self Regularized Non-Monotonic Neural Activation Function.
|
||||||
|
|
||||||
Reference: https://arxiv.org/abs/1908.08681
|
Reference: https://arxiv.org/abs/1908.08681
|
||||||
@ -301,7 +302,7 @@ class GLU(Module):
|
|||||||
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
|
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
textrm{GLU}(x) = a * \sigma(b)
|
\textrm{GLU}(x) = a * \sigma(b)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
axis (int): The dimension to split along. Default: ``-1``
|
axis (int): The dimension to split along. Default: ``-1``
|
||||||
|
@ -800,7 +800,38 @@ void init_transforms(py::module_& m) {
|
|||||||
const py::object& inputs,
|
const py::object& inputs,
|
||||||
const py::object& outputs,
|
const py::object& outputs,
|
||||||
bool shapeless) {
|
bool shapeless) {
|
||||||
return py::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless});
|
py::options options;
|
||||||
|
options.disable_function_signatures();
|
||||||
|
|
||||||
|
std::ostringstream doc;
|
||||||
|
auto name = fun.attr("__name__").cast<std::string>();
|
||||||
|
doc << name;
|
||||||
|
|
||||||
|
// Try to get the signature
|
||||||
|
auto inspect = py::module::import("inspect");
|
||||||
|
if (!inspect.attr("isbuiltin")(fun).cast<bool>()) {
|
||||||
|
doc << inspect.attr("signature")(fun)
|
||||||
|
.attr("__str__")()
|
||||||
|
.cast<std::string>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to get the doc string
|
||||||
|
if (auto d = fun.attr("__doc__"); py::isinstance<py::str>(d)) {
|
||||||
|
doc << "\n\n";
|
||||||
|
auto dstr = d.cast<std::string>();
|
||||||
|
// Add spaces to match first line indentation with remainder of
|
||||||
|
// docstring
|
||||||
|
int i = 0;
|
||||||
|
for (int i = dstr.size() - 1; i >= 0 && dstr[i] == ' '; i--) {
|
||||||
|
doc << ' ';
|
||||||
|
}
|
||||||
|
doc << dstr;
|
||||||
|
}
|
||||||
|
auto doc_str = doc.str();
|
||||||
|
return py::cpp_function(
|
||||||
|
PyCompiledFun{fun, inputs, outputs, shapeless},
|
||||||
|
py::name(name.c_str()),
|
||||||
|
py::doc(doc_str.c_str()));
|
||||||
},
|
},
|
||||||
"fun"_a,
|
"fun"_a,
|
||||||
"inputs"_a = std::nullopt,
|
"inputs"_a = std::nullopt,
|
||||||
|
Loading…
Reference in New Issue
Block a user