mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 11:28:12 +08:00
fix wraps compile (#2461)
This commit is contained in:
@@ -9,8 +9,11 @@ struct gc_func {
|
||||
PyObject_HEAD
|
||||
// Vector call implementation that forwards calls to nanobind
|
||||
PyObject* (*vectorcall)(PyObject*, PyObject* const*, size_t, PyObject*);
|
||||
// The function itself
|
||||
// The nanobind wrapper func
|
||||
PyObject* func;
|
||||
|
||||
// The original wrapped func
|
||||
PyObject* orig_func;
|
||||
// A non-owning reference to dependencies owned by 'func'
|
||||
std::vector<PyObject*> deps;
|
||||
};
|
||||
@@ -68,8 +71,7 @@ static PyGetSetDef gc_func_getset[] = {
|
||||
|
||||
static PyObject* gc_func_getattro(PyObject* self, PyObject* name_) {
|
||||
gc_func* w = (gc_func*)self;
|
||||
auto f = PyCFunction(PyType_GetSlot(Py_TYPE(w->func), Py_tp_getattro));
|
||||
return f(w->func, name_);
|
||||
return PyObject_GenericGetAttr(w->orig_func, name_);
|
||||
}
|
||||
|
||||
// Table of custom type slots we want to install
|
||||
@@ -92,9 +94,14 @@ static PyType_Spec gc_func_spec = {
|
||||
|
||||
static PyTypeObject* gc_func_tp = nullptr;
|
||||
|
||||
nb::callable mlx_func(nb::object func, std::vector<PyObject*> deps) {
|
||||
nb::callable mlx_func(
|
||||
nb::object func,
|
||||
const nb::callable& orig_func,
|
||||
std::vector<PyObject*> deps) {
|
||||
gc_func* r = (gc_func*)PyType_GenericAlloc(gc_func_tp, 0);
|
||||
r->func = func.inc_ref().ptr();
|
||||
r->orig_func = orig_func.ptr();
|
||||
deps.push_back(r->orig_func);
|
||||
r->deps = std::move(deps);
|
||||
r->vectorcall = gc_func_vectorcall;
|
||||
return nb::steal<nb::callable>((PyObject*)r);
|
||||
|
||||
@@ -10,15 +10,22 @@
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
nb::callable mlx_func(nb::object func, std::vector<PyObject*> deps);
|
||||
nb::callable mlx_func(
|
||||
nb::object func,
|
||||
const nb::callable& orig_func,
|
||||
std::vector<PyObject*> deps);
|
||||
|
||||
template <typename F, typename... Deps>
|
||||
nb::callable mlx_func(F func, Deps&&... deps) {
|
||||
nb::callable mlx_func(F func, const nb::callable& orig_func, Deps&&... deps) {
|
||||
return mlx_func(
|
||||
nb::cpp_function(std::move(func)), std::vector<PyObject*>{deps.ptr()...});
|
||||
nb::cpp_function(std::move(func)),
|
||||
orig_func,
|
||||
std::vector<PyObject*>{deps.ptr()...});
|
||||
}
|
||||
|
||||
template <typename... Deps>
|
||||
nb::callable mlx_func(nb::object func, Deps&&... deps) {
|
||||
return mlx_func(std::move(func), std::vector<PyObject*>{deps.ptr()...});
|
||||
nb::callable
|
||||
mlx_func(nb::object func, const nb::callable& orig_func, Deps&&... deps) {
|
||||
return mlx_func(
|
||||
std::move(func), orig_func, std::vector<PyObject*>{deps.ptr()...});
|
||||
}
|
||||
|
||||
@@ -1414,35 +1414,8 @@ void init_transforms(nb::module_& m) {
|
||||
const nb::object& inputs,
|
||||
const nb::object& outputs,
|
||||
bool shapeless) {
|
||||
// Try to get the name
|
||||
auto n =
|
||||
nb::hasattr(fun, "__name__") ? fun.attr("__name__") : nb::none();
|
||||
auto name = n.is_none() ? "compiled"
|
||||
: nb::cast<std::string>(fun.attr("__name__"));
|
||||
|
||||
// Try to get the signature
|
||||
std::ostringstream sig;
|
||||
sig << "def " << name;
|
||||
auto inspect = nb::module_::import_("inspect");
|
||||
if (nb::cast<bool>(inspect.attr("isroutine")(fun))) {
|
||||
sig << nb::cast<std::string>(
|
||||
inspect.attr("signature")(fun).attr("__str__")());
|
||||
} else {
|
||||
sig << "(*args, **kwargs)";
|
||||
}
|
||||
|
||||
// Try to get the doc string
|
||||
auto d = inspect.attr("getdoc")(fun);
|
||||
std::string doc =
|
||||
d.is_none() ? "MLX compiled function." : nb::cast<std::string>(d);
|
||||
|
||||
auto sig_str = sig.str();
|
||||
return mlx_func(
|
||||
nb::cpp_function(
|
||||
PyCompiledFun{fun, inputs, outputs, shapeless},
|
||||
nb::name(name.c_str()),
|
||||
nb::sig(sig_str.c_str()),
|
||||
doc.c_str()),
|
||||
nb::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless}),
|
||||
fun,
|
||||
inputs,
|
||||
outputs);
|
||||
|
||||
Reference in New Issue
Block a user