fix wraps compile (#2461)

This commit is contained in:
Awni Hannun 2025-08-04 16:14:18 -07:00 committed by GitHub
parent 6ad0889c8a
commit 0b807893a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 48 additions and 38 deletions

View File

@ -9,8 +9,11 @@ struct gc_func {
PyObject_HEAD
// Vector call implementation that forwards calls to nanobind
PyObject* (*vectorcall)(PyObject*, PyObject* const*, size_t, PyObject*);
// The function itself
// The nanobind wrapper func
PyObject* func;
// The original wrapped func
PyObject* orig_func;
// A non-owning reference to dependencies owned by 'func'
std::vector<PyObject*> deps;
};
@ -68,8 +71,7 @@ static PyGetSetDef gc_func_getset[] = {
static PyObject* gc_func_getattro(PyObject* self, PyObject* name_) {
gc_func* w = (gc_func*)self;
auto f = PyCFunction(PyType_GetSlot(Py_TYPE(w->func), Py_tp_getattro));
return f(w->func, name_);
return PyObject_GenericGetAttr(w->orig_func, name_);
}
// Table of custom type slots we want to install
@ -92,9 +94,14 @@ static PyType_Spec gc_func_spec = {
static PyTypeObject* gc_func_tp = nullptr;
nb::callable mlx_func(nb::object func, std::vector<PyObject*> deps) {
nb::callable mlx_func(
nb::object func,
const nb::callable& orig_func,
std::vector<PyObject*> deps) {
gc_func* r = (gc_func*)PyType_GenericAlloc(gc_func_tp, 0);
r->func = func.inc_ref().ptr();
r->orig_func = orig_func.ptr();
deps.push_back(r->orig_func);
r->deps = std::move(deps);
r->vectorcall = gc_func_vectorcall;
return nb::steal<nb::callable>((PyObject*)r);

View File

@ -10,15 +10,22 @@
namespace nb = nanobind;
using namespace nb::literals;
nb::callable mlx_func(nb::object func, std::vector<PyObject*> deps);
nb::callable mlx_func(
nb::object func,
const nb::callable& orig_func,
std::vector<PyObject*> deps);
template <typename F, typename... Deps>
nb::callable mlx_func(F func, Deps&&... deps) {
nb::callable mlx_func(F func, const nb::callable& orig_func, Deps&&... deps) {
return mlx_func(
nb::cpp_function(std::move(func)), std::vector<PyObject*>{deps.ptr()...});
nb::cpp_function(std::move(func)),
orig_func,
std::vector<PyObject*>{deps.ptr()...});
}
template <typename... Deps>
nb::callable mlx_func(nb::object func, Deps&&... deps) {
return mlx_func(std::move(func), std::vector<PyObject*>{deps.ptr()...});
nb::callable
mlx_func(nb::object func, const nb::callable& orig_func, Deps&&... deps) {
return mlx_func(
std::move(func), orig_func, std::vector<PyObject*>{deps.ptr()...});
}

View File

@ -1414,35 +1414,8 @@ void init_transforms(nb::module_& m) {
const nb::object& inputs,
const nb::object& outputs,
bool shapeless) {
// Try to get the name
auto n =
nb::hasattr(fun, "__name__") ? fun.attr("__name__") : nb::none();
auto name = n.is_none() ? "compiled"
: nb::cast<std::string>(fun.attr("__name__"));
// Try to get the signature
std::ostringstream sig;
sig << "def " << name;
auto inspect = nb::module_::import_("inspect");
if (nb::cast<bool>(inspect.attr("isroutine")(fun))) {
sig << nb::cast<std::string>(
inspect.attr("signature")(fun).attr("__str__")());
} else {
sig << "(*args, **kwargs)";
}
// Try to get the doc string
auto d = inspect.attr("getdoc")(fun);
std::string doc =
d.is_none() ? "MLX compiled function." : nb::cast<std::string>(d);
auto sig_str = sig.str();
return mlx_func(
nb::cpp_function(
PyCompiledFun{fun, inputs, outputs, shapeless},
nb::name(name.c_str()),
nb::sig(sig_str.c_str()),
doc.c_str()),
nb::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless}),
fun,
inputs,
outputs);

View File

@ -1,10 +1,11 @@
# Copyright © 2023-2024 Apple Inc.
import gc
import inspect
import io
import math
import unittest
from functools import partial
from functools import partial, wraps
from io import StringIO
import mlx.core as mx
@ -1014,6 +1015,28 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(d[0], d_hat[0]))
self.assertTrue(mx.allclose(d[1], d_hat[1]))
def test_wrap_compiled(self):
@mx.compile
def inner():
pass
@wraps(inner)
def wrapper():
pass
def test_compiled_preserves_attributes(self):
def inner(x: mx.array, y: str):
"""
A useful function.
"""
pass
c_inner = mx.compile(inner)
self.assertEqual(inner.__name__, c_inner.__name__)
self.assertEqual(inner.__qualname__, c_inner.__qualname__)
self.assertEqual(inner.__doc__, c_inner.__doc__)
self.assertEqual(inspect.signature(inner), inspect.signature(c_inner))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()