fix multi-output with compile

This commit is contained in:
Awni Hannun
2024-01-17 06:55:11 -08:00
parent ed4d867092
commit 390e50b163
5 changed files with 39 additions and 21 deletions

View File

@@ -816,22 +816,22 @@ void init_transforms(py::module_& m) {
as ``fun`` and returns the the same output(s).
)pbdoc");
m.def(
"disable_compiler",
&disable_compiler,
"disable_compile",
&disable_compile,
R"pbdoc(
disable_compiler() -> None
disable_compile() -> None
Globally disable compilation. Setting the environment variable
``MLX_DISABLE_COMPILER`` can also be used to disable compilation.
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
)pbdoc");
m.def(
"enable_compiler",
&enable_compiler,
"enable_compile",
&enable_compile,
R"pbdoc(
enable_compiler() -> None
Globally enable compilation. This will override the environment
variable ``MLX_DISABLE_COMPILER`` if set.
variable ``MLX_DISABLE_COMPILE`` if set.
)pbdoc");
// Register static Python object cleanup before the interpreter exits

View File

@@ -164,12 +164,12 @@ class TestCompile(mlx_tests.MLXTestCase):
n_compiled = count_prims(cfun(x))
# Check disabled
mx.disable_compiler()
mx.disable_compile()
n_uncompiled = count_prims(cfun(x))
self.assertTrue(n_compiled < n_uncompiled)
# Check renabled
mx.enable_compiler()
mx.enable_compile()
n_enable_compiled = count_prims(cfun(x))
self.assertEqual(n_compiled, n_enable_compiled)