mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-11 19:56:40 +08:00
fix wraps compile (#2461)
This commit is contained in:
parent
6ad0889c8a
commit
0b807893a7
@ -9,8 +9,11 @@ struct gc_func {
|
||||
PyObject_HEAD
|
||||
// Vector call implementation that forwards calls to nanobind
|
||||
PyObject* (*vectorcall)(PyObject*, PyObject* const*, size_t, PyObject*);
|
||||
// The function itself
|
||||
// The nanobind wrapper func
|
||||
PyObject* func;
|
||||
|
||||
// The original wrapped func
|
||||
PyObject* orig_func;
|
||||
// A non-owning reference to dependencies owned by 'func'
|
||||
std::vector<PyObject*> deps;
|
||||
};
|
||||
@ -68,8 +71,7 @@ static PyGetSetDef gc_func_getset[] = {
|
||||
|
||||
static PyObject* gc_func_getattro(PyObject* self, PyObject* name_) {
|
||||
gc_func* w = (gc_func*)self;
|
||||
auto f = PyCFunction(PyType_GetSlot(Py_TYPE(w->func), Py_tp_getattro));
|
||||
return f(w->func, name_);
|
||||
return PyObject_GenericGetAttr(w->orig_func, name_);
|
||||
}
|
||||
|
||||
// Table of custom type slots we want to install
|
||||
@ -92,9 +94,14 @@ static PyType_Spec gc_func_spec = {
|
||||
|
||||
static PyTypeObject* gc_func_tp = nullptr;
|
||||
|
||||
nb::callable mlx_func(nb::object func, std::vector<PyObject*> deps) {
|
||||
nb::callable mlx_func(
|
||||
nb::object func,
|
||||
const nb::callable& orig_func,
|
||||
std::vector<PyObject*> deps) {
|
||||
gc_func* r = (gc_func*)PyType_GenericAlloc(gc_func_tp, 0);
|
||||
r->func = func.inc_ref().ptr();
|
||||
r->orig_func = orig_func.ptr();
|
||||
deps.push_back(r->orig_func);
|
||||
r->deps = std::move(deps);
|
||||
r->vectorcall = gc_func_vectorcall;
|
||||
return nb::steal<nb::callable>((PyObject*)r);
|
||||
|
@ -10,15 +10,22 @@
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
nb::callable mlx_func(nb::object func, std::vector<PyObject*> deps);
|
||||
nb::callable mlx_func(
|
||||
nb::object func,
|
||||
const nb::callable& orig_func,
|
||||
std::vector<PyObject*> deps);
|
||||
|
||||
template <typename F, typename... Deps>
|
||||
nb::callable mlx_func(F func, Deps&&... deps) {
|
||||
nb::callable mlx_func(F func, const nb::callable& orig_func, Deps&&... deps) {
|
||||
return mlx_func(
|
||||
nb::cpp_function(std::move(func)), std::vector<PyObject*>{deps.ptr()...});
|
||||
nb::cpp_function(std::move(func)),
|
||||
orig_func,
|
||||
std::vector<PyObject*>{deps.ptr()...});
|
||||
}
|
||||
|
||||
template <typename... Deps>
|
||||
nb::callable mlx_func(nb::object func, Deps&&... deps) {
|
||||
return mlx_func(std::move(func), std::vector<PyObject*>{deps.ptr()...});
|
||||
nb::callable
|
||||
mlx_func(nb::object func, const nb::callable& orig_func, Deps&&... deps) {
|
||||
return mlx_func(
|
||||
std::move(func), orig_func, std::vector<PyObject*>{deps.ptr()...});
|
||||
}
|
||||
|
@ -1414,35 +1414,8 @@ void init_transforms(nb::module_& m) {
|
||||
const nb::object& inputs,
|
||||
const nb::object& outputs,
|
||||
bool shapeless) {
|
||||
// Try to get the name
|
||||
auto n =
|
||||
nb::hasattr(fun, "__name__") ? fun.attr("__name__") : nb::none();
|
||||
auto name = n.is_none() ? "compiled"
|
||||
: nb::cast<std::string>(fun.attr("__name__"));
|
||||
|
||||
// Try to get the signature
|
||||
std::ostringstream sig;
|
||||
sig << "def " << name;
|
||||
auto inspect = nb::module_::import_("inspect");
|
||||
if (nb::cast<bool>(inspect.attr("isroutine")(fun))) {
|
||||
sig << nb::cast<std::string>(
|
||||
inspect.attr("signature")(fun).attr("__str__")());
|
||||
} else {
|
||||
sig << "(*args, **kwargs)";
|
||||
}
|
||||
|
||||
// Try to get the doc string
|
||||
auto d = inspect.attr("getdoc")(fun);
|
||||
std::string doc =
|
||||
d.is_none() ? "MLX compiled function." : nb::cast<std::string>(d);
|
||||
|
||||
auto sig_str = sig.str();
|
||||
return mlx_func(
|
||||
nb::cpp_function(
|
||||
PyCompiledFun{fun, inputs, outputs, shapeless},
|
||||
nb::name(name.c_str()),
|
||||
nb::sig(sig_str.c_str()),
|
||||
doc.c_str()),
|
||||
nb::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless}),
|
||||
fun,
|
||||
inputs,
|
||||
outputs);
|
||||
|
@ -1,10 +1,11 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import gc
|
||||
import inspect
|
||||
import io
|
||||
import math
|
||||
import unittest
|
||||
from functools import partial
|
||||
from functools import partial, wraps
|
||||
from io import StringIO
|
||||
|
||||
import mlx.core as mx
|
||||
@ -1014,6 +1015,28 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(d[0], d_hat[0]))
|
||||
self.assertTrue(mx.allclose(d[1], d_hat[1]))
|
||||
|
||||
def test_wrap_compiled(self):
|
||||
@mx.compile
|
||||
def inner():
|
||||
pass
|
||||
|
||||
@wraps(inner)
|
||||
def wrapper():
|
||||
pass
|
||||
|
||||
def test_compiled_preserves_attributes(self):
|
||||
def inner(x: mx.array, y: str):
|
||||
"""
|
||||
A useful function.
|
||||
"""
|
||||
pass
|
||||
|
||||
c_inner = mx.compile(inner)
|
||||
self.assertEqual(inner.__name__, c_inner.__name__)
|
||||
self.assertEqual(inner.__qualname__, c_inner.__qualname__)
|
||||
self.assertEqual(inner.__doc__, c_inner.__doc__)
|
||||
self.assertEqual(inspect.signature(inner), inspect.signature(c_inner))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
Loading…
Reference in New Issue
Block a user