mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Cycle leak break (#1856)
* detect and break leaks in custom function * detect and break leaks in custom function
This commit is contained in:
parent
142b77751d
commit
2a45056ba8
@ -17,6 +17,7 @@ nanobind_add_module(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||||
|
@ -60,8 +60,56 @@ validate_and_extract_inputs(
|
|||||||
return {args_, kwargs_};
|
return {args_, kwargs_};
|
||||||
}
|
}
|
||||||
|
|
||||||
auto wrap_export_function(const nb::callable& fun) {
|
int py_function_exporter_tp_traverse(
|
||||||
return [fun](
|
PyObject* self,
|
||||||
|
visitproc visit,
|
||||||
|
void* arg);
|
||||||
|
|
||||||
|
class PyFunctionExporter {
|
||||||
|
public:
|
||||||
|
PyFunctionExporter(mx::FunctionExporter exporter, nb::handle dep)
|
||||||
|
: exporter_(std::move(exporter)), dep_(dep) {}
|
||||||
|
~PyFunctionExporter() {
|
||||||
|
nb::gil_scoped_acquire gil;
|
||||||
|
}
|
||||||
|
PyFunctionExporter(const PyFunctionExporter&) = delete;
|
||||||
|
PyFunctionExporter& operator=(const PyFunctionExporter&) = delete;
|
||||||
|
PyFunctionExporter& operator=(const PyFunctionExporter&&) = delete;
|
||||||
|
PyFunctionExporter(PyFunctionExporter&& other)
|
||||||
|
: exporter_(std::move(other.exporter_)), dep_(std::move(other.dep_)) {}
|
||||||
|
|
||||||
|
void close() {
|
||||||
|
exporter_.close();
|
||||||
|
}
|
||||||
|
void operator()(
|
||||||
|
const std::vector<mx::array>& args,
|
||||||
|
const std::map<std::string, mx::array>& kwargs) {
|
||||||
|
exporter_(args, kwargs);
|
||||||
|
}
|
||||||
|
|
||||||
|
friend int py_function_exporter_tp_traverse(PyObject*, visitproc, void*);
|
||||||
|
|
||||||
|
private:
|
||||||
|
mx::FunctionExporter exporter_;
|
||||||
|
nb::handle dep_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
 * `tp_traverse` slot for `PyFunctionExporter`.
 *
 * Reports the instance's type and the Python object stored in `dep_` to the
 * cycle collector.
 *
 * @param self  The Python instance being traversed.
 * @param visit Callback supplied by the collector.
 * @param arg   Opaque argument forwarded to `visit`.
 * @return 0 on success, or the first non-zero value returned by `visit`.
 */
int py_function_exporter_tp_traverse(
    PyObject* self,
    visitproc visit,
    void* arg) {
  // The type does not depend on the C++ payload, so it can always be visited.
  Py_VISIT(Py_TYPE(self));

  // The collector may look at an instance whose C++ object has not been
  // constructed yet; in that state there is no valid `dep_` to read.
  if (!nb::inst_ready(self)) {
    return 0;
  }

  auto* exporter = nb::inst_ptr<PyFunctionExporter>(self);
  Py_VISIT(exporter->dep_.ptr());
  return 0;
}
|
||||||
|
|
||||||
|
// Additional type slots for the exporter class: only a traverse hook is
// installed. The table ends with a zeroed sentinel entry.
PyType_Slot py_function_exporter_slots[] = {
    {Py_tp_traverse,
     reinterpret_cast<void*>(&py_function_exporter_tp_traverse)},
    {0, nullptr}};
|
||||||
|
|
||||||
|
auto wrap_export_function(nb::callable fun) {
|
||||||
|
return [fun = std::move(fun)](
|
||||||
const std::vector<mx::array>& args_,
|
const std::vector<mx::array>& args_,
|
||||||
const std::map<std::string, mx::array>& kwargs_) {
|
const std::map<std::string, mx::array>& kwargs_) {
|
||||||
auto kwargs = nb::dict();
|
auto kwargs = nb::dict();
|
||||||
@ -173,21 +221,21 @@ void init_export(nb::module_& m) {
|
|||||||
>>> out = fn((a, b), {"x": x, "y": y}[0]
|
>>> out = fn((a, b), {"x": x, "y": y}[0]
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
|
||||||
nb::class_<mx::FunctionExporter>(
|
nb::class_<PyFunctionExporter>(
|
||||||
m,
|
m,
|
||||||
"FunctionExporter",
|
"FunctionExporter",
|
||||||
|
nb::type_slots(py_function_exporter_slots),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
A context managing class for exporting multiple traces of the same
|
A context managing class for exporting multiple traces of the same
|
||||||
function to a file.
|
function to a file.
|
||||||
|
|
||||||
Make an instance of this class by calling fun:`mx.exporter`.
|
Make an instance of this class by calling fun:`mx.exporter`.
|
||||||
)pbdoc")
|
)pbdoc")
|
||||||
.def("close", &mx::FunctionExporter::close)
|
.def("close", &PyFunctionExporter::close)
|
||||||
.def(
|
.def("__enter__", [](PyFunctionExporter& exporter) { return &exporter; })
|
||||||
"__enter__", [](mx::FunctionExporter& exporter) { return &exporter; })
|
|
||||||
.def(
|
.def(
|
||||||
"__exit__",
|
"__exit__",
|
||||||
[](mx::FunctionExporter& exporter,
|
[](PyFunctionExporter& exporter,
|
||||||
const std::optional<nb::object>&,
|
const std::optional<nb::object>&,
|
||||||
const std::optional<nb::object>&,
|
const std::optional<nb::object>&,
|
||||||
const std::optional<nb::object>&) { exporter.close(); },
|
const std::optional<nb::object>&) { exporter.close(); },
|
||||||
@ -196,7 +244,7 @@ void init_export(nb::module_& m) {
|
|||||||
"traceback"_a = nb::none())
|
"traceback"_a = nb::none())
|
||||||
.def(
|
.def(
|
||||||
"__call__",
|
"__call__",
|
||||||
[](mx::FunctionExporter& exporter,
|
[](PyFunctionExporter& exporter,
|
||||||
const nb::args& args,
|
const nb::args& args,
|
||||||
const nb::kwargs& kwargs) {
|
const nb::kwargs& kwargs) {
|
||||||
auto [args_, kwargs_] =
|
auto [args_, kwargs_] =
|
||||||
@ -206,8 +254,9 @@ void init_export(nb::module_& m) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"exporter",
|
"exporter",
|
||||||
[](const std::string& file, const nb::callable& fun, bool shapeless) {
|
[](const std::string& file, nb::callable fun, bool shapeless) {
|
||||||
return mx::exporter(file, wrap_export_function(fun), shapeless);
|
return PyFunctionExporter{
|
||||||
|
mx::exporter(file, wrap_export_function(fun), shapeless), fun};
|
||||||
},
|
},
|
||||||
"file"_a,
|
"file"_a,
|
||||||
"fun"_a,
|
"fun"_a,
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
|
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
|
|
||||||
|
void init_mlx_func(nb::module_&);
|
||||||
void init_array(nb::module_&);
|
void init_array(nb::module_&);
|
||||||
void init_device(nb::module_&);
|
void init_device(nb::module_&);
|
||||||
void init_stream(nb::module_&);
|
void init_stream(nb::module_&);
|
||||||
@ -28,6 +29,7 @@ NB_MODULE(core, m) {
|
|||||||
nb::module_::import_("mlx._os_warning");
|
nb::module_::import_("mlx._os_warning");
|
||||||
nb::set_leak_warnings(false);
|
nb::set_leak_warnings(false);
|
||||||
|
|
||||||
|
init_mlx_func(m);
|
||||||
init_device(m);
|
init_device(m);
|
||||||
init_stream(m);
|
init_stream(m);
|
||||||
init_array(m);
|
init_array(m);
|
||||||
|
108
python/src/mlx_func.cpp
Normal file
108
python/src/mlx_func.cpp
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "python/src/mlx_func.h"

#include <new>
|
||||||
|
|
||||||
|
// A garbage collected function which wraps nb::cpp_function
|
||||||
|
// See https://github.com/wjakob/nanobind/discussions/919
|
||||||
|
|
||||||
|
struct gc_func {
|
||||||
|
PyObject_HEAD
|
||||||
|
// Vector call implementation that forwards calls to nanobind
|
||||||
|
PyObject* (*vectorcall)(PyObject*, PyObject* const*, size_t, PyObject*);
|
||||||
|
// The function itself
|
||||||
|
PyObject* func;
|
||||||
|
// A non-owning reference to dependencies owned by 'func'
|
||||||
|
std::vector<PyObject*> deps;
|
||||||
|
};
|
||||||
|
|
||||||
|
// tp_traverse slot: reports the wrapped function, every recorded dependency
// and the instance's type to the cycle collector. Returns 0, or the first
// non-zero value produced by 'visit'.
int gc_func_tp_traverse(PyObject* self, visitproc visit, void* arg) {
  auto* wrapper = reinterpret_cast<gc_func*>(self);
  Py_VISIT(wrapper->func);
  for (size_t i = 0; i < wrapper->deps.size(); ++i) {
    Py_VISIT(wrapper->deps[i]);
  }
  Py_VISIT(Py_TYPE(self));
  return 0;
}
|
||||||
|
|
||||||
|
// tp_clear slot: breaks reference cycles by dropping the reference to the
// wrapped function. Always returns 0.
int gc_func_tp_clear(PyObject* self) {
  gc_func* w = (gc_func*)self;
  // 'deps' only borrows objects that are kept alive by 'func'. Forget them
  // before releasing 'func' so that a later traversal of this (still
  // existing) object cannot visit pointers that may no longer be valid.
  w->deps.clear();
  Py_CLEAR(w->func);
  return 0;
}
|
||||||
|
|
||||||
|
// Getter for "__doc__": looks the attribute up on the wrapped function and
// returns the result of that lookup (a new reference, or NULL on failure).
PyObject* gc_func_get_doc(PyObject* self, void*) {
  PyObject* wrapped = reinterpret_cast<gc_func*>(self)->func;
  return PyObject_GetAttrString(wrapped, "__doc__");
}
|
||||||
|
|
||||||
|
// Getter for "__nb_signature__": looks the attribute up on the wrapped
// function and returns the result of that lookup.
PyObject* gc_func_get_sig(PyObject* self, void*) {
  PyObject* wrapped = reinterpret_cast<gc_func*>(self)->func;
  return PyObject_GetAttrString(wrapped, "__nb_signature__");
}
|
||||||
|
|
||||||
|
// Vectorcall implementation of the wrapper: passes the arguments through
// unchanged to the wrapped function and returns its result.
PyObject* gc_func_vectorcall(
    PyObject* self,
    PyObject* const* args,
    size_t nargs,
    PyObject* kwnames) {
  PyObject* target = reinterpret_cast<gc_func*>(self)->func;
  return PyObject_Vectorcall(target, args, nargs, kwnames);
}
|
||||||
|
|
||||||
|
// tp_dealloc slot: releases everything owned by a gc_func instance.
void gc_func_dealloc(PyObject* self) {
  // Remember the type before the instance memory is released.
  PyTypeObject* tp = Py_TYPE(self);
  gc_func* w = (gc_func*)self;

  // Stop the collector from looking at a half-destroyed object.
  PyObject_GC_UnTrack(self);
  Py_XDECREF(w->func);

  // The instance memory is managed by Python, so the C++ member is never
  // destroyed automatically; run its destructor to free the vector's buffer.
  w->deps.~vector();

  PyObject_GC_Del(self);

  // Instances of heap types (this type is created with PyType_FromSpec) hold
  // a reference to their type, which the deallocator must give back.
  Py_DECREF(tp);
}
|
||||||
|
|
||||||
|
// Member table: publishes where the vectorcall pointer lives inside gc_func.
static PyMemberDef gc_func_members[] = {
    {"__vectorcalloffset__",
     T_PYSSIZET,
     static_cast<Py_ssize_t>(offsetof(gc_func, vectorcall)),
     READONLY,
     nullptr},
    // Sentinel entry.
    {nullptr, 0, 0, 0, nullptr}};
|
||||||
|
|
||||||
|
// Read-only computed attributes (no setter, no docstring, no closure).
static PyGetSetDef gc_func_getset[] = {
    {"__doc__", &gc_func_get_doc, nullptr, nullptr, nullptr},
    {"__nb_signature__", &gc_func_get_sig, nullptr, nullptr, nullptr},
    // Sentinel entry.
    {nullptr, nullptr, nullptr, nullptr, nullptr}};
|
||||||
|
|
||||||
|
// tp_getattro slot: delegates the attribute lookup to the tp_getattro slot of
// the wrapped function's type, applied to the wrapped function.
static PyObject* gc_func_getattro(PyObject* self, PyObject* name_) {
  PyObject* target = reinterpret_cast<gc_func*>(self)->func;
  auto lookup = reinterpret_cast<getattrofunc>(
      PyType_GetSlot(Py_TYPE(target), Py_tp_getattro));
  return lookup(target, name_);
}
|
||||||
|
|
||||||
|
// Table of custom type slots we want to install
|
||||||
|
PyType_Slot gc_func_slots[] = {
|
||||||
|
{Py_tp_traverse, (void*)gc_func_tp_traverse},
|
||||||
|
{Py_tp_clear, (void*)gc_func_tp_clear},
|
||||||
|
{Py_tp_getset, (void*)gc_func_getset},
|
||||||
|
{Py_tp_getattro, (void*)gc_func_getattro},
|
||||||
|
{Py_tp_members, (void*)gc_func_members},
|
||||||
|
{Py_tp_call, (void*)PyVectorcall_Call},
|
||||||
|
{Py_tp_dealloc, (void*)gc_func_dealloc},
|
||||||
|
{0, 0}};
|
||||||
|
|
||||||
|
// Specification from which the gc_func type object is created: a fixed-size,
// garbage-collected type with vectorcall support.
static PyType_Spec gc_func_spec = {
    /* .name = */ "mlx.gc_func",
    /* .basicsize = */ static_cast<int>(sizeof(gc_func)),
    /* .itemsize = */ 0,
    /* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | NB_HAVE_VECTORCALL,
    /* .slots = */ gc_func_slots};
|
||||||
|
|
||||||
|
static PyTypeObject* gc_func_tp = nullptr;
|
||||||
|
|
||||||
|
// Creates a garbage-collected callable that forwards calls to 'func'.
//
// 'func' is the callable to wrap; the wrapper takes its own reference to it.
// 'deps' are borrowed pointers to the Python objects 'func' depends on; they
// are stored so the traverse slot can report them to the cycle collector.
//
// Returns the new wrapper. Throws nb::python_error if the instance cannot be
// allocated.
nb::callable mlx_func(nb::object func, std::vector<PyObject*> deps) {
  gc_func* r = (gc_func*)PyType_GenericAlloc(gc_func_tp, 0);
  if (r == nullptr) {
    // Allocation failed and Python has set an error; surface it to the caller
    // instead of writing through a null pointer.
    throw nb::python_error();
  }
  r->func = func.inc_ref().ptr();
  // The instance memory comes from Python's allocator, so the C++ member has
  // not been constructed yet: construct it in place rather than assigning to
  // an object that was never created.
  ::new (static_cast<void*>(&r->deps)) std::vector<PyObject*>(std::move(deps));
  r->vectorcall = gc_func_vectorcall;
  return nb::steal<nb::callable>((PyObject*)r);
}
|
||||||
|
|
||||||
|
// Creates the gc_func type object from its spec and stores it for later use
// by mlx_func. Raises through nanobind if the type cannot be created.
void init_mlx_func(nb::module_& m) {
  PyObject* created = PyType_FromSpec(&gc_func_spec);
  gc_func_tp = reinterpret_cast<PyTypeObject*>(created);
  if (gc_func_tp == nullptr) {
    nb::raise("Could not register MLX function type.");
  }
}
|
24
python/src/mlx_func.h
Normal file
24
python/src/mlx_func.h
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <nanobind/nanobind.h>
|
||||||
|
#include <nanobind/stl/function.h>
|
||||||
|
|
||||||
|
namespace nb = nanobind;
|
||||||
|
using namespace nb::literals;
|
||||||
|
|
||||||
|
nb::callable mlx_func(nb::object func, std::vector<PyObject*> deps);
|
||||||
|
|
||||||
|
template <typename F, typename... Deps>
|
||||||
|
nb::callable mlx_func(F func, Deps&&... deps) {
|
||||||
|
return mlx_func(
|
||||||
|
nb::cpp_function(std::move(func)), std::vector<PyObject*>{deps.ptr()...});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Deps>
|
||||||
|
nb::callable mlx_func(nb::object func, Deps&&... deps) {
|
||||||
|
return mlx_func(std::move(func), std::vector<PyObject*>{deps.ptr()...});
|
||||||
|
}
|
@ -19,6 +19,7 @@
|
|||||||
#include "mlx/transforms.h"
|
#include "mlx/transforms.h"
|
||||||
#include "mlx/transforms_impl.h"
|
#include "mlx/transforms_impl.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
#include "python/src/mlx_func.h"
|
||||||
#include "python/src/trees.h"
|
#include "python/src/trees.h"
|
||||||
|
|
||||||
namespace mx = mlx::core;
|
namespace mx = mlx::core;
|
||||||
@ -543,9 +544,9 @@ struct PyCompiledFun {
|
|||||||
|
|
||||||
tree_cache().erase(fun_id);
|
tree_cache().erase(fun_id);
|
||||||
mx::detail::compile_erase(fun_id);
|
mx::detail::compile_erase(fun_id);
|
||||||
fun.release().dec_ref();
|
fun.reset();
|
||||||
captured_inputs.release().dec_ref();
|
captured_inputs.reset();
|
||||||
captured_outputs.release().dec_ref();
|
captured_outputs.reset();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -555,7 +556,7 @@ class PyCheckpointedFun {
|
|||||||
~PyCheckpointedFun() {
|
~PyCheckpointedFun() {
|
||||||
nb::gil_scoped_acquire gil;
|
nb::gil_scoped_acquire gil;
|
||||||
|
|
||||||
fun_.release().dec_ref();
|
fun_.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct InnerFunction {
|
struct InnerFunction {
|
||||||
@ -573,8 +574,8 @@ class PyCheckpointedFun {
|
|||||||
~InnerFunction() {
|
~InnerFunction() {
|
||||||
nb::gil_scoped_acquire gil;
|
nb::gil_scoped_acquire gil;
|
||||||
|
|
||||||
fun_.release().dec_ref();
|
fun_.reset();
|
||||||
args_structure_.release().dec_ref();
|
args_structure_.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
|
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
|
||||||
@ -609,6 +610,10 @@ class PyCheckpointedFun {
|
|||||||
nb::callable fun_;
|
nb::callable fun_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg);
|
||||||
|
|
||||||
|
int py_custom_function_tp_clear(PyObject* self);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* PyCustomFunction is the class that implements the python decorator
|
* PyCustomFunction is the class that implements the python decorator
|
||||||
* `mx.custom_function`.
|
* `mx.custom_function`.
|
||||||
@ -641,17 +646,7 @@ class PyCustomFunction {
|
|||||||
PyCustomFunction(nb::callable fun) : fun_(std::move(fun)) {}
|
PyCustomFunction(nb::callable fun) : fun_(std::move(fun)) {}
|
||||||
~PyCustomFunction() {
|
~PyCustomFunction() {
|
||||||
nb::gil_scoped_acquire gil;
|
nb::gil_scoped_acquire gil;
|
||||||
|
reset();
|
||||||
fun_.release().dec_ref();
|
|
||||||
if (vjp_fun_.has_value()) {
|
|
||||||
(*vjp_fun_).release().dec_ref();
|
|
||||||
}
|
|
||||||
if (jvp_fun_.has_value()) {
|
|
||||||
(*jvp_fun_).release().dec_ref();
|
|
||||||
}
|
|
||||||
if (vmap_fun_.has_value()) {
|
|
||||||
(*vmap_fun_).release().dec_ref();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct InnerFunction {
|
struct InnerFunction {
|
||||||
@ -669,10 +664,10 @@ class PyCustomFunction {
|
|||||||
~InnerFunction() {
|
~InnerFunction() {
|
||||||
nb::gil_scoped_acquire gil;
|
nb::gil_scoped_acquire gil;
|
||||||
|
|
||||||
fun_.release().dec_ref();
|
fun_.reset();
|
||||||
input_structure_.release().dec_ref();
|
input_structure_.reset();
|
||||||
if (output_structure_.use_count() == 1) {
|
if (output_structure_.use_count() == 1) {
|
||||||
output_structure_->release().dec_ref();
|
output_structure_->reset();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -703,10 +698,10 @@ class PyCustomFunction {
|
|||||||
~InnerVJPFunction() {
|
~InnerVJPFunction() {
|
||||||
nb::gil_scoped_acquire gil;
|
nb::gil_scoped_acquire gil;
|
||||||
|
|
||||||
vjp_fun_.release().dec_ref();
|
vjp_fun_.reset();
|
||||||
input_structure_.release().dec_ref();
|
input_structure_.reset();
|
||||||
if (output_structure_.use_count() == 1) {
|
if (output_structure_.use_count() == 1) {
|
||||||
output_structure_->release().dec_ref();
|
output_structure_->reset();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -746,8 +741,8 @@ class PyCustomFunction {
|
|||||||
~InnerJVPFunction() {
|
~InnerJVPFunction() {
|
||||||
nb::gil_scoped_acquire gil;
|
nb::gil_scoped_acquire gil;
|
||||||
|
|
||||||
jvp_fun_.release().dec_ref();
|
jvp_fun_.reset();
|
||||||
input_structure_.release().dec_ref();
|
input_structure_.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<mx::array> operator()(
|
std::vector<mx::array> operator()(
|
||||||
@ -801,8 +796,8 @@ class PyCustomFunction {
|
|||||||
~InnerVmapFunction() {
|
~InnerVmapFunction() {
|
||||||
nb::gil_scoped_acquire gil;
|
nb::gil_scoped_acquire gil;
|
||||||
|
|
||||||
vmap_fun_.release().dec_ref();
|
vmap_fun_.reset();
|
||||||
input_structure_.release().dec_ref();
|
input_structure_.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<std::vector<mx::array>, std::vector<int>> operator()(
|
std::pair<std::vector<mx::array>, std::vector<int>> operator()(
|
||||||
@ -904,6 +899,20 @@ class PyCustomFunction {
|
|||||||
vmap_fun_ = vmap_fun;
|
vmap_fun_ = vmap_fun;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
void reset() {
|
||||||
|
fun_.reset();
|
||||||
|
if (vjp_fun_.has_value()) {
|
||||||
|
(*vjp_fun_).reset();
|
||||||
|
}
|
||||||
|
if (jvp_fun_.has_value()) {
|
||||||
|
(*jvp_fun_).reset();
|
||||||
|
}
|
||||||
|
if (vmap_fun_.has_value()) {
|
||||||
|
(*vmap_fun_).reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
friend int py_custom_function_tp_traverse(PyObject*, visitproc, void*);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::optional<InnerVJPFunction> make_vjp_function(
|
std::optional<InnerVJPFunction> make_vjp_function(
|
||||||
@ -940,10 +949,40 @@ class PyCustomFunction {
|
|||||||
std::optional<nb::callable> vmap_fun_;
|
std::optional<nb::callable> vmap_fun_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
 * `tp_traverse` slot for `PyCustomFunction`.
 *
 * Reports the instance's type and every Python callable held by the C++
 * object (the main function and the optional vjp / jvp / vmap functions) to
 * the cycle collector.
 *
 * @return 0 on success, or the first non-zero value returned by `visit`.
 */
int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) {
  // The type does not depend on the C++ payload, so it can always be visited.
  Py_VISIT(Py_TYPE(self));

  // Traversal can happen after the instance was allocated but before its C++
  // constructor has finished; the members must not be read in that state.
  if (!nb::inst_ready(self)) {
    return 0;
  }

  auto* p = nb::inst_ptr<PyCustomFunction>(self);

  // The members already are Python handles, so their raw pointers can be
  // visited directly. Py_VISIT skips null pointers.
  PyObject* fun = p->fun_.ptr();
  Py_VISIT(fun);
  if (p->vjp_fun_.has_value()) {
    PyObject* vjp = p->vjp_fun_->ptr();
    Py_VISIT(vjp);
  }
  if (p->jvp_fun_.has_value()) {
    PyObject* jvp = p->jvp_fun_->ptr();
    Py_VISIT(jvp);
  }
  if (p->vmap_fun_.has_value()) {
    PyObject* vmap = p->vmap_fun_->ptr();
    Py_VISIT(vmap);
  }
  return 0;
}
|
||||||
|
// tp_clear slot for PyCustomFunction: asks the C++ object to drop the Python
// callables it holds. Always returns 0.
int py_custom_function_tp_clear(PyObject* self) {
  nb::inst_ptr<PyCustomFunction>(self)->reset();
  return 0;
}
|
||||||
|
// Additional type slots for the custom_function class: traverse and clear
// hooks for the cycle collector, followed by a zeroed sentinel entry.
PyType_Slot py_custom_function_slots[] = {
    {Py_tp_traverse, reinterpret_cast<void*>(&py_custom_function_tp_traverse)},
    {Py_tp_clear, reinterpret_cast<void*>(&py_custom_function_tp_clear)},
    {0, nullptr}};
|
||||||
|
|
||||||
void init_transforms(nb::module_& m) {
|
void init_transforms(nb::module_& m) {
|
||||||
nb::class_<PyCustomFunction>(
|
nb::class_<PyCustomFunction>(
|
||||||
m,
|
m,
|
||||||
"custom_function",
|
"custom_function",
|
||||||
|
nb::type_slots(py_custom_function_slots),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Set up a function for custom gradient and vmap definitions.
|
Set up a function for custom gradient and vmap definitions.
|
||||||
|
|
||||||
@ -1224,8 +1263,10 @@ void init_transforms(nb::module_& m) {
|
|||||||
const StrOrSet& argnames) {
|
const StrOrSet& argnames) {
|
||||||
auto [argnums_vec, argnames_set] =
|
auto [argnums_vec, argnames_set] =
|
||||||
validate_argnums_argnames(argnums, argnames);
|
validate_argnums_argnames(argnums, argnames);
|
||||||
return nb::cpp_function(py_value_and_grad(
|
return mlx_func(
|
||||||
fun, argnums_vec, argnames_set, "[value_and_grad]", false));
|
py_value_and_grad(
|
||||||
|
fun, argnums_vec, argnames_set, "[value_and_grad]", false),
|
||||||
|
fun);
|
||||||
},
|
},
|
||||||
"fun"_a,
|
"fun"_a,
|
||||||
"argnums"_a = nb::none(),
|
"argnums"_a = nb::none(),
|
||||||
@ -1290,9 +1331,11 @@ void init_transforms(nb::module_& m) {
|
|||||||
validate_argnums_argnames(argnums, argnames);
|
validate_argnums_argnames(argnums, argnames);
|
||||||
auto fn =
|
auto fn =
|
||||||
py_value_and_grad(fun, argnums_vec, argnames_set, "[grad]", true);
|
py_value_and_grad(fun, argnums_vec, argnames_set, "[grad]", true);
|
||||||
return nb::cpp_function([fn](nb::args& args, nb::kwargs& kwargs) {
|
return mlx_func(
|
||||||
|
[fn = std::move(fn)](nb::args& args, nb::kwargs& kwargs) {
|
||||||
return fn(args, kwargs).second;
|
return fn(args, kwargs).second;
|
||||||
});
|
},
|
||||||
|
fun);
|
||||||
},
|
},
|
||||||
"fun"_a,
|
"fun"_a,
|
||||||
"argnums"_a = nb::none(),
|
"argnums"_a = nb::none(),
|
||||||
@ -1324,7 +1367,8 @@ void init_transforms(nb::module_& m) {
|
|||||||
[](const nb::callable& fun,
|
[](const nb::callable& fun,
|
||||||
const nb::object& in_axes,
|
const nb::object& in_axes,
|
||||||
const nb::object& out_axes) {
|
const nb::object& out_axes) {
|
||||||
return nb::cpp_function(py_vmap(fun, in_axes, out_axes));
|
return mlx_func(
|
||||||
|
py_vmap(fun, in_axes, out_axes), fun, in_axes, out_axes);
|
||||||
},
|
},
|
||||||
"fun"_a,
|
"fun"_a,
|
||||||
"in_axes"_a = 0,
|
"in_axes"_a = 0,
|
||||||
@ -1379,11 +1423,15 @@ void init_transforms(nb::module_& m) {
|
|||||||
d.is_none() ? "MLX compiled function." : nb::cast<std::string>(d);
|
d.is_none() ? "MLX compiled function." : nb::cast<std::string>(d);
|
||||||
|
|
||||||
auto sig_str = sig.str();
|
auto sig_str = sig.str();
|
||||||
return nb::cpp_function(
|
return mlx_func(
|
||||||
|
nb::cpp_function(
|
||||||
PyCompiledFun{fun, inputs, outputs, shapeless},
|
PyCompiledFun{fun, inputs, outputs, shapeless},
|
||||||
nb::name(name.c_str()),
|
nb::name(name.c_str()),
|
||||||
nb::sig(sig_str.c_str()),
|
nb::sig(sig_str.c_str()),
|
||||||
doc.c_str());
|
doc.c_str()),
|
||||||
|
fun,
|
||||||
|
inputs,
|
||||||
|
outputs);
|
||||||
},
|
},
|
||||||
"fun"_a,
|
"fun"_a,
|
||||||
"inputs"_a = nb::none(),
|
"inputs"_a = nb::none(),
|
||||||
@ -1435,7 +1483,7 @@ void init_transforms(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"checkpoint",
|
"checkpoint",
|
||||||
[](nb::callable fun) { return nb::cpp_function(PyCheckpointedFun{fun}); },
|
[](nb::callable fun) { return mlx_func(PyCheckpointedFun{fun}, fun); },
|
||||||
"fun"_a);
|
"fun"_a);
|
||||||
|
|
||||||
// Register static Python object cleanup before the interpreter exits
|
// Register static Python object cleanup before the interpreter exits
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import gc
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@ -737,6 +738,38 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
|||||||
expected[4:-5:-2] = tan_b
|
expected[4:-5:-2] = tan_b
|
||||||
self.assertTrue(mx.allclose(grad, expected))
|
self.assertTrue(mx.allclose(grad, expected))
|
||||||
|
|
||||||
|
def test_leaks(self):
|
||||||
|
for transform in [
|
||||||
|
mx.grad,
|
||||||
|
mx.value_and_grad,
|
||||||
|
mx.custom_function,
|
||||||
|
mx.checkpoint,
|
||||||
|
]:
|
||||||
|
if mx.metal.is_available():
|
||||||
|
mem_pre = mx.metal.get_active_memory()
|
||||||
|
else:
|
||||||
|
mem_pre = 0
|
||||||
|
|
||||||
|
def outer():
|
||||||
|
d = {}
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
return d["x"]
|
||||||
|
|
||||||
|
d["f"] = transform(f)
|
||||||
|
d["x"] = mx.array([0] * 1000)
|
||||||
|
|
||||||
|
for _ in range(5):
|
||||||
|
outer()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
if mx.metal.is_available():
|
||||||
|
mem_post = mx.metal.get_active_memory()
|
||||||
|
else:
|
||||||
|
mem_post = 0
|
||||||
|
|
||||||
|
self.assertEqual(mem_pre, mem_post)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import gc
|
||||||
import io
|
import io
|
||||||
import unittest
|
import unittest
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@ -926,6 +927,32 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(out[0].shape, (3, 1, 4, 2))
|
self.assertEqual(out[0].shape, (3, 1, 4, 2))
|
||||||
self.assertEqual(out[1].shape, (2, 2, 5))
|
self.assertEqual(out[1].shape, (2, 2, 5))
|
||||||
|
|
||||||
|
def test_leaks(self):
|
||||||
|
if mx.metal.is_available():
|
||||||
|
mem_pre = mx.metal.get_active_memory()
|
||||||
|
else:
|
||||||
|
mem_pre = 0
|
||||||
|
|
||||||
|
def outer():
|
||||||
|
d = {}
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
return d["x"]
|
||||||
|
|
||||||
|
d["f"] = mx.compile(f)
|
||||||
|
d["x"] = mx.array([0] * 1000)
|
||||||
|
|
||||||
|
for _ in range(5):
|
||||||
|
outer()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
if mx.metal.is_available():
|
||||||
|
mem_post = mx.metal.get_active_memory()
|
||||||
|
else:
|
||||||
|
mem_post = 0
|
||||||
|
|
||||||
|
self.assertEqual(mem_pre, mem_post)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import gc
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
@ -239,6 +240,33 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
|||||||
constants_size = constant.nbytes + 8192
|
constants_size = constant.nbytes + 8192
|
||||||
self.assertTrue(os.path.getsize(path) < constants_size)
|
self.assertTrue(os.path.getsize(path) < constants_size)
|
||||||
|
|
||||||
|
def test_leaks(self):
|
||||||
|
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||||
|
if mx.metal.is_available():
|
||||||
|
mem_pre = mx.metal.get_active_memory()
|
||||||
|
else:
|
||||||
|
mem_pre = 0
|
||||||
|
|
||||||
|
def outer():
|
||||||
|
d = {}
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
return d["x"]
|
||||||
|
|
||||||
|
d["f"] = mx.exporter(path, f)
|
||||||
|
d["x"] = mx.array([0] * 1000)
|
||||||
|
|
||||||
|
for _ in range(5):
|
||||||
|
outer()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
if mx.metal.is_available():
|
||||||
|
mem_post = mx.metal.get_active_memory()
|
||||||
|
else:
|
||||||
|
mem_post = 0
|
||||||
|
|
||||||
|
self.assertEqual(mem_pre, mem_post)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import gc
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@ -608,6 +609,32 @@ class TestVmap(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(fx.shape, (5, 6, 7))
|
self.assertEqual(fx.shape, (5, 6, 7))
|
||||||
self.assertEqual(fy.shape, (4, 5, 6, 7))
|
self.assertEqual(fy.shape, (4, 5, 6, 7))
|
||||||
|
|
||||||
|
def test_leaks(self):
|
||||||
|
if mx.metal.is_available():
|
||||||
|
mem_pre = mx.metal.get_active_memory()
|
||||||
|
else:
|
||||||
|
mem_pre = 0
|
||||||
|
|
||||||
|
def outer():
|
||||||
|
d = {}
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
return d["x"]
|
||||||
|
|
||||||
|
d["f"] = mx.vmap(f)
|
||||||
|
d["x"] = mx.array([0] * 1000)
|
||||||
|
|
||||||
|
for _ in range(5):
|
||||||
|
outer()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
if mx.metal.is_available():
|
||||||
|
mem_post = mx.metal.get_active_memory()
|
||||||
|
else:
|
||||||
|
mem_post = 0
|
||||||
|
|
||||||
|
self.assertEqual(mem_pre, mem_post)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user