fix wraps compile (#2461)

This commit is contained in:
Awni Hannun
2025-08-04 16:14:18 -07:00
committed by GitHub
parent 6ad0889c8a
commit 0b807893a7
4 changed files with 48 additions and 38 deletions

View File

@@ -1414,35 +1414,8 @@ void init_transforms(nb::module_& m) {
const nb::object& inputs,
const nb::object& outputs,
bool shapeless) {
// Try to get the name
auto n =
nb::hasattr(fun, "__name__") ? fun.attr("__name__") : nb::none();
auto name = n.is_none() ? "compiled"
: nb::cast<std::string>(fun.attr("__name__"));
// Try to get the signature
std::ostringstream sig;
sig << "def " << name;
auto inspect = nb::module_::import_("inspect");
if (nb::cast<bool>(inspect.attr("isroutine")(fun))) {
sig << nb::cast<std::string>(
inspect.attr("signature")(fun).attr("__str__")());
} else {
sig << "(*args, **kwargs)";
}
// Try to get the doc string
auto d = inspect.attr("getdoc")(fun);
std::string doc =
d.is_none() ? "MLX compiled function." : nb::cast<std::string>(d);
auto sig_str = sig.str();
return mlx_func(
nb::cpp_function(
PyCompiledFun{fun, inputs, outputs, shapeless},
nb::name(name.c_str()),
nb::sig(sig_str.c_str()),
doc.c_str()),
nb::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless}),
fun,
inputs,
outputs);