fix multi-output with compile

This commit is contained in:
Awni Hannun 2024-01-17 06:55:11 -08:00
parent ed4d867092
commit 390e50b163
5 changed files with 39 additions and 21 deletions

View File

@ -15,7 +15,7 @@ namespace detail {
bool& compiler_disabled() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILER")) {
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
return true;
} else {
return false;
@ -343,13 +343,30 @@ std::vector<array> compile_replace(
if (!a.has_primitive()) {
trace_to_real.insert({a.id(), a});
} else {
// Find real inputs
std::vector<array> real_inputs;
for (auto& in : a.inputs()) {
real_inputs.push_back(trace_to_real.at(in.id()));
}
auto real_a = array(
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
trace_to_real.insert({a.id(), std::move(real_a)});
if (a.siblings().empty()) {
auto real_a = array(
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
trace_to_real.insert({a.id(), std::move(real_a)});
} else {
// Ensure the order is correct for multi-output primitives
std::vector<std::vector<int>> shapes;
std::vector<Dtype> types;
auto trace_out = a.outputs();
for (auto& o : trace_out) {
shapes.push_back(o.shape());
types.push_back(o.dtype());
}
auto real_out =
array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
for (int i = 0; i < trace_out.size(); ++i) {
trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])});
}
}
}
}
@ -412,11 +429,11 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
return detail::compile(fun, fun_id);
}
void disable_compiler() {
void disable_compile() {
detail::compiler_disabled() = true;
}
void enable_compiler() {
void enable_compile() {
detail::compiler_disabled() = false;
}

View File

@ -11,15 +11,15 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
/** Globally disable compilation.
* Setting the environment variable ``MLX_DISABLE_COMPILER`` can also
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
* be used to disable compilation.
*/
void disable_compiler();
void disable_compile();
/** Globally enable compilation.
* This will override the environment variable ``MLX_DISABLE_COMPILER``.
* This will override the environment variable ``MLX_DISABLE_COMPILE``.
*/
void enable_compiler();
void enable_compile();
void eval(const std::vector<array>& outputs);

View File

@ -816,22 +816,22 @@ void init_transforms(py::module_& m) {
as ``fun`` and returns the the same output(s).
)pbdoc");
m.def(
"disable_compiler",
&disable_compiler,
"disable_compile",
&disable_compile,
R"pbdoc(
disable_compiler() -> None
disable_compile() -> None
Globally disable compilation. Setting the environment variable
``MLX_DISABLE_COMPILER`` can also be used to disable compilation.
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
)pbdoc");
m.def(
"enable_compiler",
&enable_compiler,
"enable_compile",
&enable_compile,
R"pbdoc(
enable_compiler() -> None
Globally enable compilation. This will override the environment
variable ``MLX_DISABLE_COMPILER`` if set.
variable ``MLX_DISABLE_COMPILE`` if set.
)pbdoc");
// Register static Python object cleanup before the interpreter exits

View File

@ -164,12 +164,12 @@ class TestCompile(mlx_tests.MLXTestCase):
n_compiled = count_prims(cfun(x))
# Check disabled
mx.disable_compiler()
mx.disable_compile()
n_uncompiled = count_prims(cfun(x))
self.assertTrue(n_compiled < n_uncompiled)
# Check renabled
mx.enable_compiler()
mx.enable_compile()
n_enable_compiled = count_prims(cfun(x))
self.assertEqual(n_compiled, n_enable_compiled)

View File

@ -1,5 +1,6 @@
// Copyright © 2023 Apple Inc.
#include <iostream>
#include "doctest/doctest.h"
#include "mlx/mlx.h"
@ -99,9 +100,9 @@ TEST_CASE("test nested compile") {
TEST_CASE("test enable and disable compile") {
CHECK_THROWS(compile(nullptr));
disable_compiler();
disable_compile();
compile(nullptr);
enable_compiler();
enable_compile();
CHECK_THROWS(compile(nullptr));
}