mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix wraps compile (#2461)
This commit is contained in:
@@ -1414,35 +1414,8 @@ void init_transforms(nb::module_& m) {
|
||||
const nb::object& inputs,
|
||||
const nb::object& outputs,
|
||||
bool shapeless) {
|
||||
// Try to get the name
|
||||
auto n =
|
||||
nb::hasattr(fun, "__name__") ? fun.attr("__name__") : nb::none();
|
||||
auto name = n.is_none() ? "compiled"
|
||||
: nb::cast<std::string>(fun.attr("__name__"));
|
||||
|
||||
// Try to get the signature
|
||||
std::ostringstream sig;
|
||||
sig << "def " << name;
|
||||
auto inspect = nb::module_::import_("inspect");
|
||||
if (nb::cast<bool>(inspect.attr("isroutine")(fun))) {
|
||||
sig << nb::cast<std::string>(
|
||||
inspect.attr("signature")(fun).attr("__str__")());
|
||||
} else {
|
||||
sig << "(*args, **kwargs)";
|
||||
}
|
||||
|
||||
// Try to get the doc string
|
||||
auto d = inspect.attr("getdoc")(fun);
|
||||
std::string doc =
|
||||
d.is_none() ? "MLX compiled function." : nb::cast<std::string>(d);
|
||||
|
||||
auto sig_str = sig.str();
|
||||
return mlx_func(
|
||||
nb::cpp_function(
|
||||
PyCompiledFun{fun, inputs, outputs, shapeless},
|
||||
nb::name(name.c_str()),
|
||||
nb::sig(sig_str.c_str()),
|
||||
doc.c_str()),
|
||||
nb::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless}),
|
||||
fun,
|
||||
inputs,
|
||||
outputs);
|
||||
|
||||
Reference in New Issue
Block a user