diff --git a/python/src/mlx_func.cpp b/python/src/mlx_func.cpp index 2f0589bb6f..87912fdda3 100644 --- a/python/src/mlx_func.cpp +++ b/python/src/mlx_func.cpp @@ -9,8 +9,11 @@ struct gc_func { PyObject_HEAD // Vector call implementation that forwards calls to nanobind PyObject* (*vectorcall)(PyObject*, PyObject* const*, size_t, PyObject*); - // The function itself + // The nanobind wrapper func PyObject* func; + + // The original wrapped func + PyObject* orig_func; // A non-owning reference to dependencies owned by 'func' std::vector deps; }; @@ -68,8 +71,7 @@ static PyGetSetDef gc_func_getset[] = { static PyObject* gc_func_getattro(PyObject* self, PyObject* name_) { gc_func* w = (gc_func*)self; - auto f = PyCFunction(PyType_GetSlot(Py_TYPE(w->func), Py_tp_getattro)); - return f(w->func, name_); + return PyObject_GenericGetAttr(w->orig_func, name_); } // Table of custom type slots we want to install @@ -92,9 +94,14 @@ static PyType_Spec gc_func_spec = { static PyTypeObject* gc_func_tp = nullptr; -nb::callable mlx_func(nb::object func, std::vector deps) { +nb::callable mlx_func( + nb::object func, + const nb::callable& orig_func, + std::vector deps) { gc_func* r = (gc_func*)PyType_GenericAlloc(gc_func_tp, 0); r->func = func.inc_ref().ptr(); + r->orig_func = orig_func.ptr(); + deps.push_back(r->orig_func); r->deps = std::move(deps); r->vectorcall = gc_func_vectorcall; return nb::steal((PyObject*)r); diff --git a/python/src/mlx_func.h b/python/src/mlx_func.h index 90cfb57bec..79c8376dee 100644 --- a/python/src/mlx_func.h +++ b/python/src/mlx_func.h @@ -10,15 +10,22 @@ namespace nb = nanobind; using namespace nb::literals; -nb::callable mlx_func(nb::object func, std::vector deps); +nb::callable mlx_func( + nb::object func, + const nb::callable& orig_func, + std::vector deps); template -nb::callable mlx_func(F func, Deps&&... deps) { +nb::callable mlx_func(F func, const nb::callable& orig_func, Deps&&... deps) { return mlx_func( - nb::cpp_function(std::move(func)), std::vector{deps.ptr()...}); + nb::cpp_function(std::move(func)), + orig_func, + std::vector{deps.ptr()...}); } template -nb::callable mlx_func(nb::object func, Deps&&... deps) { - return mlx_func(std::move(func), std::vector{deps.ptr()...}); +nb::callable +mlx_func(nb::object func, const nb::callable& orig_func, Deps&&... deps) { + return mlx_func( + std::move(func), orig_func, std::vector{deps.ptr()...}); } diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index c47942b720..2506f50b06 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -1414,35 +1414,8 @@ void init_transforms(nb::module_& m) { const nb::object& inputs, const nb::object& outputs, bool shapeless) { - // Try to get the name - auto n = - nb::hasattr(fun, "__name__") ? fun.attr("__name__") : nb::none(); - auto name = n.is_none() ? "compiled" - : nb::cast(fun.attr("__name__")); - - // Try to get the signature - std::ostringstream sig; - sig << "def " << name; - auto inspect = nb::module_::import_("inspect"); - if (nb::cast(inspect.attr("isroutine")(fun))) { - sig << nb::cast( - inspect.attr("signature")(fun).attr("__str__")()); - } else { - sig << "(*args, **kwargs)"; - } - - // Try to get the doc string - auto d = inspect.attr("getdoc")(fun); - std::string doc = - d.is_none() ? "MLX compiled function." : nb::cast(d); - - auto sig_str = sig.str(); return mlx_func( - nb::cpp_function( - PyCompiledFun{fun, inputs, outputs, shapeless}, - nb::name(name.c_str()), - nb::sig(sig_str.c_str()), - doc.c_str()), + nb::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless}), fun, inputs, outputs); diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index ada2b1484c..5eb32ce4d8 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1,10 +1,11 @@ # Copyright © 2023-2024 Apple Inc. import gc +import inspect import io import math import unittest -from functools import partial +from functools import partial, wraps from io import StringIO import mlx.core as mx @@ -1014,6 +1015,28 @@ class TestCompile(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(d[0], d_hat[0])) self.assertTrue(mx.allclose(d[1], d_hat[1])) + def test_wrap_compiled(self): + @mx.compile + def inner(): + pass + + @wraps(inner) + def wrapper(): + pass + + def test_compiled_preserves_attributes(self): + def inner(x: mx.array, y: str): + """ + A useful function. + """ + pass + + c_inner = mx.compile(inner) + self.assertEqual(inner.__name__, c_inner.__name__) + self.assertEqual(inner.__qualname__, c_inner.__qualname__) + self.assertEqual(inner.__doc__, c_inner.__doc__) + self.assertEqual(inspect.signature(inner), inspect.signature(c_inner)) + if __name__ == "__main__": mlx_tests.MLXTestRunner()