fix wraps compile (#2461)

This commit is contained in:
Awni Hannun 2025-08-04 16:14:18 -07:00 committed by GitHub
parent 6ad0889c8a
commit 0b807893a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 48 additions and 38 deletions

View File

@ -9,8 +9,11 @@ struct gc_func {
PyObject_HEAD PyObject_HEAD
// Vector call implementation that forwards calls to nanobind // Vector call implementation that forwards calls to nanobind
PyObject* (*vectorcall)(PyObject*, PyObject* const*, size_t, PyObject*); PyObject* (*vectorcall)(PyObject*, PyObject* const*, size_t, PyObject*);
// The function itself // The nanobind wrapper func
PyObject* func; PyObject* func;
// The original wrapped func
PyObject* orig_func;
// A non-owning reference to dependencies owned by 'func' // A non-owning reference to dependencies owned by 'func'
std::vector<PyObject*> deps; std::vector<PyObject*> deps;
}; };
@ -68,8 +71,7 @@ static PyGetSetDef gc_func_getset[] = {
static PyObject* gc_func_getattro(PyObject* self, PyObject* name_) { static PyObject* gc_func_getattro(PyObject* self, PyObject* name_) {
gc_func* w = (gc_func*)self; gc_func* w = (gc_func*)self;
auto f = PyCFunction(PyType_GetSlot(Py_TYPE(w->func), Py_tp_getattro)); return PyObject_GenericGetAttr(w->orig_func, name_);
return f(w->func, name_);
} }
// Table of custom type slots we want to install // Table of custom type slots we want to install
@ -92,9 +94,14 @@ static PyType_Spec gc_func_spec = {
static PyTypeObject* gc_func_tp = nullptr; static PyTypeObject* gc_func_tp = nullptr;
nb::callable mlx_func(nb::object func, std::vector<PyObject*> deps) { nb::callable mlx_func(
nb::object func,
const nb::callable& orig_func,
std::vector<PyObject*> deps) {
gc_func* r = (gc_func*)PyType_GenericAlloc(gc_func_tp, 0); gc_func* r = (gc_func*)PyType_GenericAlloc(gc_func_tp, 0);
r->func = func.inc_ref().ptr(); r->func = func.inc_ref().ptr();
r->orig_func = orig_func.ptr();
deps.push_back(r->orig_func);
r->deps = std::move(deps); r->deps = std::move(deps);
r->vectorcall = gc_func_vectorcall; r->vectorcall = gc_func_vectorcall;
return nb::steal<nb::callable>((PyObject*)r); return nb::steal<nb::callable>((PyObject*)r);

View File

@ -10,15 +10,22 @@
namespace nb = nanobind; namespace nb = nanobind;
using namespace nb::literals; using namespace nb::literals;
nb::callable mlx_func(nb::object func, std::vector<PyObject*> deps); nb::callable mlx_func(
nb::object func,
const nb::callable& orig_func,
std::vector<PyObject*> deps);
template <typename F, typename... Deps> template <typename F, typename... Deps>
nb::callable mlx_func(F func, Deps&&... deps) { nb::callable mlx_func(F func, const nb::callable& orig_func, Deps&&... deps) {
return mlx_func( return mlx_func(
nb::cpp_function(std::move(func)), std::vector<PyObject*>{deps.ptr()...}); nb::cpp_function(std::move(func)),
orig_func,
std::vector<PyObject*>{deps.ptr()...});
} }
template <typename... Deps> template <typename... Deps>
nb::callable mlx_func(nb::object func, Deps&&... deps) { nb::callable
return mlx_func(std::move(func), std::vector<PyObject*>{deps.ptr()...}); mlx_func(nb::object func, const nb::callable& orig_func, Deps&&... deps) {
return mlx_func(
std::move(func), orig_func, std::vector<PyObject*>{deps.ptr()...});
} }

View File

@ -1414,35 +1414,8 @@ void init_transforms(nb::module_& m) {
const nb::object& inputs, const nb::object& inputs,
const nb::object& outputs, const nb::object& outputs,
bool shapeless) { bool shapeless) {
// Try to get the name
auto n =
nb::hasattr(fun, "__name__") ? fun.attr("__name__") : nb::none();
auto name = n.is_none() ? "compiled"
: nb::cast<std::string>(fun.attr("__name__"));
// Try to get the signature
std::ostringstream sig;
sig << "def " << name;
auto inspect = nb::module_::import_("inspect");
if (nb::cast<bool>(inspect.attr("isroutine")(fun))) {
sig << nb::cast<std::string>(
inspect.attr("signature")(fun).attr("__str__")());
} else {
sig << "(*args, **kwargs)";
}
// Try to get the doc string
auto d = inspect.attr("getdoc")(fun);
std::string doc =
d.is_none() ? "MLX compiled function." : nb::cast<std::string>(d);
auto sig_str = sig.str();
return mlx_func( return mlx_func(
nb::cpp_function( nb::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless}),
PyCompiledFun{fun, inputs, outputs, shapeless},
nb::name(name.c_str()),
nb::sig(sig_str.c_str()),
doc.c_str()),
fun, fun,
inputs, inputs,
outputs); outputs);

View File

@ -1,10 +1,11 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
import gc import gc
import inspect
import io import io
import math import math
import unittest import unittest
from functools import partial from functools import partial, wraps
from io import StringIO from io import StringIO
import mlx.core as mx import mlx.core as mx
@ -1014,6 +1015,28 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(d[0], d_hat[0])) self.assertTrue(mx.allclose(d[0], d_hat[0]))
self.assertTrue(mx.allclose(d[1], d_hat[1])) self.assertTrue(mx.allclose(d[1], d_hat[1]))
def test_wrap_compiled(self):
@mx.compile
def inner():
pass
@wraps(inner)
def wrapper():
pass
def test_compiled_preserves_attributes(self):
def inner(x: mx.array, y: str):
"""
A useful function.
"""
pass
c_inner = mx.compile(inner)
self.assertEqual(inner.__name__, c_inner.__name__)
self.assertEqual(inner.__qualname__, c_inner.__qualname__)
self.assertEqual(inner.__doc__, c_inner.__doc__)
self.assertEqual(inspect.signature(inner), inspect.signature(c_inner))
if __name__ == "__main__": if __name__ == "__main__":
mlx_tests.MLXTestRunner() mlx_tests.MLXTestRunner()