mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Cycle leak break (#1856)
* detect and break leaks in custom function * detect and break leaks in custom function
This commit is contained in:
parent
142b77751d
commit
2a45056ba8
@ -17,6 +17,7 @@ nanobind_add_module(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||
|
@ -60,8 +60,56 @@ validate_and_extract_inputs(
|
||||
return {args_, kwargs_};
|
||||
}
|
||||
|
||||
auto wrap_export_function(const nb::callable& fun) {
|
||||
return [fun](
|
||||
int py_function_exporter_tp_traverse(
|
||||
PyObject* self,
|
||||
visitproc visit,
|
||||
void* arg);
|
||||
|
||||
class PyFunctionExporter {
|
||||
public:
|
||||
PyFunctionExporter(mx::FunctionExporter exporter, nb::handle dep)
|
||||
: exporter_(std::move(exporter)), dep_(dep) {}
|
||||
~PyFunctionExporter() {
|
||||
nb::gil_scoped_acquire gil;
|
||||
}
|
||||
PyFunctionExporter(const PyFunctionExporter&) = delete;
|
||||
PyFunctionExporter& operator=(const PyFunctionExporter&) = delete;
|
||||
PyFunctionExporter& operator=(const PyFunctionExporter&&) = delete;
|
||||
PyFunctionExporter(PyFunctionExporter&& other)
|
||||
: exporter_(std::move(other.exporter_)), dep_(std::move(other.dep_)) {}
|
||||
|
||||
void close() {
|
||||
exporter_.close();
|
||||
}
|
||||
void operator()(
|
||||
const std::vector<mx::array>& args,
|
||||
const std::map<std::string, mx::array>& kwargs) {
|
||||
exporter_(args, kwargs);
|
||||
}
|
||||
|
||||
friend int py_function_exporter_tp_traverse(PyObject*, visitproc, void*);
|
||||
|
||||
private:
|
||||
mx::FunctionExporter exporter_;
|
||||
nb::handle dep_;
|
||||
};
|
||||
|
||||
int py_function_exporter_tp_traverse(
|
||||
PyObject* self,
|
||||
visitproc visit,
|
||||
void* arg) {
|
||||
auto* p = nb::inst_ptr<PyFunctionExporter>(self);
|
||||
Py_VISIT(p->dep_.ptr());
|
||||
Py_VISIT(Py_TYPE(self));
|
||||
return 0;
|
||||
}
|
||||
|
||||
PyType_Slot py_function_exporter_slots[] = {
|
||||
{Py_tp_traverse, (void*)py_function_exporter_tp_traverse},
|
||||
{0, 0}};
|
||||
|
||||
auto wrap_export_function(nb::callable fun) {
|
||||
return [fun = std::move(fun)](
|
||||
const std::vector<mx::array>& args_,
|
||||
const std::map<std::string, mx::array>& kwargs_) {
|
||||
auto kwargs = nb::dict();
|
||||
@ -173,21 +221,21 @@ void init_export(nb::module_& m) {
|
||||
>>> out = fn((a, b), {"x": x, "y": y}[0]
|
||||
)pbdoc");
|
||||
|
||||
nb::class_<mx::FunctionExporter>(
|
||||
nb::class_<PyFunctionExporter>(
|
||||
m,
|
||||
"FunctionExporter",
|
||||
nb::type_slots(py_function_exporter_slots),
|
||||
R"pbdoc(
|
||||
A context managing class for exporting multiple traces of the same
|
||||
function to a file.
|
||||
|
||||
Make an instance of this class by calling fun:`mx.exporter`.
|
||||
)pbdoc")
|
||||
.def("close", &mx::FunctionExporter::close)
|
||||
.def(
|
||||
"__enter__", [](mx::FunctionExporter& exporter) { return &exporter; })
|
||||
.def("close", &PyFunctionExporter::close)
|
||||
.def("__enter__", [](PyFunctionExporter& exporter) { return &exporter; })
|
||||
.def(
|
||||
"__exit__",
|
||||
[](mx::FunctionExporter& exporter,
|
||||
[](PyFunctionExporter& exporter,
|
||||
const std::optional<nb::object>&,
|
||||
const std::optional<nb::object>&,
|
||||
const std::optional<nb::object>&) { exporter.close(); },
|
||||
@ -196,7 +244,7 @@ void init_export(nb::module_& m) {
|
||||
"traceback"_a = nb::none())
|
||||
.def(
|
||||
"__call__",
|
||||
[](mx::FunctionExporter& exporter,
|
||||
[](PyFunctionExporter& exporter,
|
||||
const nb::args& args,
|
||||
const nb::kwargs& kwargs) {
|
||||
auto [args_, kwargs_] =
|
||||
@ -206,8 +254,9 @@ void init_export(nb::module_& m) {
|
||||
|
||||
m.def(
|
||||
"exporter",
|
||||
[](const std::string& file, const nb::callable& fun, bool shapeless) {
|
||||
return mx::exporter(file, wrap_export_function(fun), shapeless);
|
||||
[](const std::string& file, nb::callable fun, bool shapeless) {
|
||||
return PyFunctionExporter{
|
||||
mx::exporter(file, wrap_export_function(fun), shapeless), fun};
|
||||
},
|
||||
"file"_a,
|
||||
"fun"_a,
|
||||
|
@ -7,6 +7,7 @@
|
||||
|
||||
namespace nb = nanobind;
|
||||
|
||||
void init_mlx_func(nb::module_&);
|
||||
void init_array(nb::module_&);
|
||||
void init_device(nb::module_&);
|
||||
void init_stream(nb::module_&);
|
||||
@ -28,6 +29,7 @@ NB_MODULE(core, m) {
|
||||
nb::module_::import_("mlx._os_warning");
|
||||
nb::set_leak_warnings(false);
|
||||
|
||||
init_mlx_func(m);
|
||||
init_device(m);
|
||||
init_stream(m);
|
||||
init_array(m);
|
||||
|
108
python/src/mlx_func.cpp
Normal file
108
python/src/mlx_func.cpp
Normal file
@ -0,0 +1,108 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "python/src/mlx_func.h"
|
||||
|
||||
// A garbage collected function which wraps nb::cpp_function
|
||||
// See https://github.com/wjakob/nanobind/discussions/919
|
||||
|
||||
struct gc_func {
|
||||
PyObject_HEAD
|
||||
// Vector call implementation that forwards calls to nanobind
|
||||
PyObject* (*vectorcall)(PyObject*, PyObject* const*, size_t, PyObject*);
|
||||
// The function itself
|
||||
PyObject* func;
|
||||
// A non-owning reference to dependencies owned by 'func'
|
||||
std::vector<PyObject*> deps;
|
||||
};
|
||||
|
||||
int gc_func_tp_traverse(PyObject* self, visitproc visit, void* arg) {
|
||||
gc_func* w = (gc_func*)self;
|
||||
Py_VISIT(w->func);
|
||||
for (auto d : w->deps) {
|
||||
Py_VISIT(d);
|
||||
}
|
||||
Py_VISIT(Py_TYPE(self));
|
||||
return 0;
|
||||
};
|
||||
|
||||
int gc_func_tp_clear(PyObject* self) {
|
||||
gc_func* w = (gc_func*)self;
|
||||
Py_CLEAR(w->func);
|
||||
return 0;
|
||||
}
|
||||
|
||||
PyObject* gc_func_get_doc(PyObject* self, void*) {
|
||||
return PyObject_GetAttrString(((gc_func*)self)->func, "__doc__");
|
||||
}
|
||||
|
||||
PyObject* gc_func_get_sig(PyObject* self, void*) {
|
||||
return PyObject_GetAttrString(((gc_func*)self)->func, "__nb_signature__");
|
||||
}
|
||||
|
||||
PyObject* gc_func_vectorcall(
|
||||
PyObject* self,
|
||||
PyObject* const* args,
|
||||
size_t nargs,
|
||||
PyObject* kwnames) {
|
||||
return PyObject_Vectorcall(((gc_func*)self)->func, args, nargs, kwnames);
|
||||
}
|
||||
|
||||
void gc_func_dealloc(PyObject* self) {
|
||||
PyObject_GC_UnTrack(self);
|
||||
Py_XDECREF(((gc_func*)self)->func);
|
||||
PyObject_GC_Del(self);
|
||||
}
|
||||
|
||||
static PyMemberDef gc_func_members[] = {
|
||||
{"__vectorcalloffset__",
|
||||
T_PYSSIZET,
|
||||
(Py_ssize_t)offsetof(gc_func, vectorcall),
|
||||
READONLY,
|
||||
nullptr},
|
||||
{nullptr, 0, 0, 0, nullptr}};
|
||||
|
||||
static PyGetSetDef gc_func_getset[] = {
|
||||
{"__doc__", gc_func_get_doc, nullptr, nullptr, nullptr},
|
||||
{"__nb_signature__", gc_func_get_sig, nullptr, nullptr, nullptr},
|
||||
{nullptr, nullptr, nullptr, nullptr, nullptr}};
|
||||
|
||||
static PyObject* gc_func_getattro(PyObject* self, PyObject* name_) {
|
||||
gc_func* w = (gc_func*)self;
|
||||
auto f = PyCFunction(PyType_GetSlot(Py_TYPE(w->func), Py_tp_getattro));
|
||||
return f(w->func, name_);
|
||||
}
|
||||
|
||||
// Table of custom type slots we want to install
|
||||
PyType_Slot gc_func_slots[] = {
|
||||
{Py_tp_traverse, (void*)gc_func_tp_traverse},
|
||||
{Py_tp_clear, (void*)gc_func_tp_clear},
|
||||
{Py_tp_getset, (void*)gc_func_getset},
|
||||
{Py_tp_getattro, (void*)gc_func_getattro},
|
||||
{Py_tp_members, (void*)gc_func_members},
|
||||
{Py_tp_call, (void*)PyVectorcall_Call},
|
||||
{Py_tp_dealloc, (void*)gc_func_dealloc},
|
||||
{0, 0}};
|
||||
|
||||
static PyType_Spec gc_func_spec = {
|
||||
/* .name = */ "mlx.gc_func",
|
||||
/* .basicsize = */ (int)sizeof(gc_func),
|
||||
/* .itemsize = */ 0,
|
||||
/* .flags = */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | NB_HAVE_VECTORCALL,
|
||||
/* .slots = */ gc_func_slots};
|
||||
|
||||
static PyTypeObject* gc_func_tp = nullptr;
|
||||
|
||||
nb::callable mlx_func(nb::object func, std::vector<PyObject*> deps) {
|
||||
gc_func* r = (gc_func*)PyType_GenericAlloc(gc_func_tp, 0);
|
||||
r->func = func.inc_ref().ptr();
|
||||
r->deps = std::move(deps);
|
||||
r->vectorcall = gc_func_vectorcall;
|
||||
return nb::steal<nb::callable>((PyObject*)r);
|
||||
}
|
||||
|
||||
void init_mlx_func(nb::module_& m) {
|
||||
gc_func_tp = (PyTypeObject*)PyType_FromSpec(&gc_func_spec);
|
||||
if (!gc_func_tp) {
|
||||
nb::raise("Could not register MLX function type.");
|
||||
}
|
||||
}
|
24
python/src/mlx_func.h
Normal file
24
python/src/mlx_func.h
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/function.h>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
nb::callable mlx_func(nb::object func, std::vector<PyObject*> deps);
|
||||
|
||||
template <typename F, typename... Deps>
|
||||
nb::callable mlx_func(F func, Deps&&... deps) {
|
||||
return mlx_func(
|
||||
nb::cpp_function(std::move(func)), std::vector<PyObject*>{deps.ptr()...});
|
||||
}
|
||||
|
||||
template <typename... Deps>
|
||||
nb::callable mlx_func(nb::object func, Deps&&... deps) {
|
||||
return mlx_func(std::move(func), std::vector<PyObject*>{deps.ptr()...});
|
||||
}
|
@ -19,6 +19,7 @@
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "python/src/mlx_func.h"
|
||||
#include "python/src/trees.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
@ -543,9 +544,9 @@ struct PyCompiledFun {
|
||||
|
||||
tree_cache().erase(fun_id);
|
||||
mx::detail::compile_erase(fun_id);
|
||||
fun.release().dec_ref();
|
||||
captured_inputs.release().dec_ref();
|
||||
captured_outputs.release().dec_ref();
|
||||
fun.reset();
|
||||
captured_inputs.reset();
|
||||
captured_outputs.reset();
|
||||
}
|
||||
};
|
||||
|
||||
@ -555,7 +556,7 @@ class PyCheckpointedFun {
|
||||
~PyCheckpointedFun() {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
fun_.release().dec_ref();
|
||||
fun_.reset();
|
||||
}
|
||||
|
||||
struct InnerFunction {
|
||||
@ -573,8 +574,8 @@ class PyCheckpointedFun {
|
||||
~InnerFunction() {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
fun_.release().dec_ref();
|
||||
args_structure_.release().dec_ref();
|
||||
fun_.reset();
|
||||
args_structure_.reset();
|
||||
}
|
||||
|
||||
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
|
||||
@ -609,6 +610,10 @@ class PyCheckpointedFun {
|
||||
nb::callable fun_;
|
||||
};
|
||||
|
||||
int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg);
|
||||
|
||||
int py_custom_function_tp_clear(PyObject* self);
|
||||
|
||||
/**
|
||||
* PyCustomFunction is the class that implements the python decorator
|
||||
* `mx.custom_function`.
|
||||
@ -641,17 +646,7 @@ class PyCustomFunction {
|
||||
PyCustomFunction(nb::callable fun) : fun_(std::move(fun)) {}
|
||||
~PyCustomFunction() {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
fun_.release().dec_ref();
|
||||
if (vjp_fun_.has_value()) {
|
||||
(*vjp_fun_).release().dec_ref();
|
||||
}
|
||||
if (jvp_fun_.has_value()) {
|
||||
(*jvp_fun_).release().dec_ref();
|
||||
}
|
||||
if (vmap_fun_.has_value()) {
|
||||
(*vmap_fun_).release().dec_ref();
|
||||
}
|
||||
reset();
|
||||
}
|
||||
|
||||
struct InnerFunction {
|
||||
@ -669,10 +664,10 @@ class PyCustomFunction {
|
||||
~InnerFunction() {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
fun_.release().dec_ref();
|
||||
input_structure_.release().dec_ref();
|
||||
fun_.reset();
|
||||
input_structure_.reset();
|
||||
if (output_structure_.use_count() == 1) {
|
||||
output_structure_->release().dec_ref();
|
||||
output_structure_->reset();
|
||||
}
|
||||
}
|
||||
|
||||
@ -703,10 +698,10 @@ class PyCustomFunction {
|
||||
~InnerVJPFunction() {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
vjp_fun_.release().dec_ref();
|
||||
input_structure_.release().dec_ref();
|
||||
vjp_fun_.reset();
|
||||
input_structure_.reset();
|
||||
if (output_structure_.use_count() == 1) {
|
||||
output_structure_->release().dec_ref();
|
||||
output_structure_->reset();
|
||||
}
|
||||
}
|
||||
|
||||
@ -746,8 +741,8 @@ class PyCustomFunction {
|
||||
~InnerJVPFunction() {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
jvp_fun_.release().dec_ref();
|
||||
input_structure_.release().dec_ref();
|
||||
jvp_fun_.reset();
|
||||
input_structure_.reset();
|
||||
}
|
||||
|
||||
std::vector<mx::array> operator()(
|
||||
@ -801,8 +796,8 @@ class PyCustomFunction {
|
||||
~InnerVmapFunction() {
|
||||
nb::gil_scoped_acquire gil;
|
||||
|
||||
vmap_fun_.release().dec_ref();
|
||||
input_structure_.release().dec_ref();
|
||||
vmap_fun_.reset();
|
||||
input_structure_.reset();
|
||||
}
|
||||
|
||||
std::pair<std::vector<mx::array>, std::vector<int>> operator()(
|
||||
@ -904,6 +899,20 @@ class PyCustomFunction {
|
||||
vmap_fun_ = vmap_fun;
|
||||
return *this;
|
||||
}
|
||||
void reset() {
|
||||
fun_.reset();
|
||||
if (vjp_fun_.has_value()) {
|
||||
(*vjp_fun_).reset();
|
||||
}
|
||||
if (jvp_fun_.has_value()) {
|
||||
(*jvp_fun_).reset();
|
||||
}
|
||||
if (vmap_fun_.has_value()) {
|
||||
(*vmap_fun_).reset();
|
||||
}
|
||||
}
|
||||
|
||||
friend int py_custom_function_tp_traverse(PyObject*, visitproc, void*);
|
||||
|
||||
private:
|
||||
std::optional<InnerVJPFunction> make_vjp_function(
|
||||
@ -940,10 +949,40 @@ class PyCustomFunction {
|
||||
std::optional<nb::callable> vmap_fun_;
|
||||
};
|
||||
|
||||
int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) {
|
||||
auto* p = nb::inst_ptr<PyCustomFunction>(self);
|
||||
nb::handle v = nb::find(p->fun_);
|
||||
Py_VISIT(v.ptr());
|
||||
if (p->vjp_fun_.has_value()) {
|
||||
nb::handle v = nb::find(*(p->vjp_fun_));
|
||||
Py_VISIT(v.ptr());
|
||||
}
|
||||
if (p->jvp_fun_.has_value()) {
|
||||
nb::handle v = nb::find(*(p->jvp_fun_));
|
||||
Py_VISIT(v.ptr());
|
||||
}
|
||||
if (p->vmap_fun_.has_value()) {
|
||||
nb::handle v = nb::find(*(p->vmap_fun_));
|
||||
Py_VISIT(v.ptr());
|
||||
}
|
||||
Py_VISIT(Py_TYPE(self));
|
||||
return 0;
|
||||
}
|
||||
int py_custom_function_tp_clear(PyObject* self) {
|
||||
auto* p = nb::inst_ptr<PyCustomFunction>(self);
|
||||
p->reset();
|
||||
return 0;
|
||||
}
|
||||
PyType_Slot py_custom_function_slots[] = {
|
||||
{Py_tp_traverse, (void*)py_custom_function_tp_traverse},
|
||||
{Py_tp_clear, (void*)py_custom_function_tp_clear},
|
||||
{0, 0}};
|
||||
|
||||
void init_transforms(nb::module_& m) {
|
||||
nb::class_<PyCustomFunction>(
|
||||
m,
|
||||
"custom_function",
|
||||
nb::type_slots(py_custom_function_slots),
|
||||
R"pbdoc(
|
||||
Set up a function for custom gradient and vmap definitions.
|
||||
|
||||
@ -1224,8 +1263,10 @@ void init_transforms(nb::module_& m) {
|
||||
const StrOrSet& argnames) {
|
||||
auto [argnums_vec, argnames_set] =
|
||||
validate_argnums_argnames(argnums, argnames);
|
||||
return nb::cpp_function(py_value_and_grad(
|
||||
fun, argnums_vec, argnames_set, "[value_and_grad]", false));
|
||||
return mlx_func(
|
||||
py_value_and_grad(
|
||||
fun, argnums_vec, argnames_set, "[value_and_grad]", false),
|
||||
fun);
|
||||
},
|
||||
"fun"_a,
|
||||
"argnums"_a = nb::none(),
|
||||
@ -1290,9 +1331,11 @@ void init_transforms(nb::module_& m) {
|
||||
validate_argnums_argnames(argnums, argnames);
|
||||
auto fn =
|
||||
py_value_and_grad(fun, argnums_vec, argnames_set, "[grad]", true);
|
||||
return nb::cpp_function([fn](nb::args& args, nb::kwargs& kwargs) {
|
||||
return fn(args, kwargs).second;
|
||||
});
|
||||
return mlx_func(
|
||||
[fn = std::move(fn)](nb::args& args, nb::kwargs& kwargs) {
|
||||
return fn(args, kwargs).second;
|
||||
},
|
||||
fun);
|
||||
},
|
||||
"fun"_a,
|
||||
"argnums"_a = nb::none(),
|
||||
@ -1324,7 +1367,8 @@ void init_transforms(nb::module_& m) {
|
||||
[](const nb::callable& fun,
|
||||
const nb::object& in_axes,
|
||||
const nb::object& out_axes) {
|
||||
return nb::cpp_function(py_vmap(fun, in_axes, out_axes));
|
||||
return mlx_func(
|
||||
py_vmap(fun, in_axes, out_axes), fun, in_axes, out_axes);
|
||||
},
|
||||
"fun"_a,
|
||||
"in_axes"_a = 0,
|
||||
@ -1379,11 +1423,15 @@ void init_transforms(nb::module_& m) {
|
||||
d.is_none() ? "MLX compiled function." : nb::cast<std::string>(d);
|
||||
|
||||
auto sig_str = sig.str();
|
||||
return nb::cpp_function(
|
||||
PyCompiledFun{fun, inputs, outputs, shapeless},
|
||||
nb::name(name.c_str()),
|
||||
nb::sig(sig_str.c_str()),
|
||||
doc.c_str());
|
||||
return mlx_func(
|
||||
nb::cpp_function(
|
||||
PyCompiledFun{fun, inputs, outputs, shapeless},
|
||||
nb::name(name.c_str()),
|
||||
nb::sig(sig_str.c_str()),
|
||||
doc.c_str()),
|
||||
fun,
|
||||
inputs,
|
||||
outputs);
|
||||
},
|
||||
"fun"_a,
|
||||
"inputs"_a = nb::none(),
|
||||
@ -1435,7 +1483,7 @@ void init_transforms(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"checkpoint",
|
||||
[](nb::callable fun) { return nb::cpp_function(PyCheckpointedFun{fun}); },
|
||||
[](nb::callable fun) { return mlx_func(PyCheckpointedFun{fun}, fun); },
|
||||
"fun"_a);
|
||||
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
@ -737,6 +738,38 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
expected[4:-5:-2] = tan_b
|
||||
self.assertTrue(mx.allclose(grad, expected))
|
||||
|
||||
def test_leaks(self):
|
||||
for transform in [
|
||||
mx.grad,
|
||||
mx.value_and_grad,
|
||||
mx.custom_function,
|
||||
mx.checkpoint,
|
||||
]:
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
def outer():
|
||||
d = {}
|
||||
|
||||
def f(x):
|
||||
return d["x"]
|
||||
|
||||
d["f"] = transform(f)
|
||||
d["x"] = mx.array([0] * 1000)
|
||||
|
||||
for _ in range(5):
|
||||
outer()
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
self.assertEqual(mem_pre, mem_post)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import gc
|
||||
import io
|
||||
import unittest
|
||||
from functools import partial
|
||||
@ -926,6 +927,32 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(out[0].shape, (3, 1, 4, 2))
|
||||
self.assertEqual(out[1].shape, (2, 2, 5))
|
||||
|
||||
def test_leaks(self):
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
def outer():
|
||||
d = {}
|
||||
|
||||
def f(x):
|
||||
return d["x"]
|
||||
|
||||
d["f"] = mx.compile(f)
|
||||
d["x"] = mx.array([0] * 1000)
|
||||
|
||||
for _ in range(5):
|
||||
outer()
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
self.assertEqual(mem_pre, mem_post)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import gc
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
@ -239,6 +240,33 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
||||
constants_size = constant.nbytes + 8192
|
||||
self.assertTrue(os.path.getsize(path) < constants_size)
|
||||
|
||||
def test_leaks(self):
|
||||
path = os.path.join(self.test_dir, "fn.mlxfn")
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
def outer():
|
||||
d = {}
|
||||
|
||||
def f(x):
|
||||
return d["x"]
|
||||
|
||||
d["f"] = mx.exporter(path, f)
|
||||
d["x"] = mx.array([0] * 1000)
|
||||
|
||||
for _ in range(5):
|
||||
outer()
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
self.assertEqual(mem_pre, mem_post)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
@ -608,6 +609,32 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(fx.shape, (5, 6, 7))
|
||||
self.assertEqual(fy.shape, (4, 5, 6, 7))
|
||||
|
||||
def test_leaks(self):
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_pre = 0
|
||||
|
||||
def outer():
|
||||
d = {}
|
||||
|
||||
def f(x):
|
||||
return d["x"]
|
||||
|
||||
d["f"] = mx.vmap(f)
|
||||
d["x"] = mx.array([0] * 1000)
|
||||
|
||||
for _ in range(5):
|
||||
outer()
|
||||
gc.collect()
|
||||
|
||||
if mx.metal.is_available():
|
||||
mem_post = mx.metal.get_active_memory()
|
||||
else:
|
||||
mem_post = 0
|
||||
|
||||
self.assertEqual(mem_pre, mem_post)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user