mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
fix multi-output with compile
This commit is contained in:
parent
ed4d867092
commit
390e50b163
@ -15,7 +15,7 @@ namespace detail {
|
|||||||
|
|
||||||
bool& compiler_disabled() {
|
bool& compiler_disabled() {
|
||||||
auto get_val = []() {
|
auto get_val = []() {
|
||||||
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILER")) {
|
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
@ -343,13 +343,30 @@ std::vector<array> compile_replace(
|
|||||||
if (!a.has_primitive()) {
|
if (!a.has_primitive()) {
|
||||||
trace_to_real.insert({a.id(), a});
|
trace_to_real.insert({a.id(), a});
|
||||||
} else {
|
} else {
|
||||||
|
// Find real inputs
|
||||||
std::vector<array> real_inputs;
|
std::vector<array> real_inputs;
|
||||||
for (auto& in : a.inputs()) {
|
for (auto& in : a.inputs()) {
|
||||||
real_inputs.push_back(trace_to_real.at(in.id()));
|
real_inputs.push_back(trace_to_real.at(in.id()));
|
||||||
}
|
}
|
||||||
|
if (a.siblings().empty()) {
|
||||||
auto real_a = array(
|
auto real_a = array(
|
||||||
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
|
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
|
||||||
trace_to_real.insert({a.id(), std::move(real_a)});
|
trace_to_real.insert({a.id(), std::move(real_a)});
|
||||||
|
} else {
|
||||||
|
// Ensure the order is correct for multi-output primitives
|
||||||
|
std::vector<std::vector<int>> shapes;
|
||||||
|
std::vector<Dtype> types;
|
||||||
|
auto trace_out = a.outputs();
|
||||||
|
for (auto& o : trace_out) {
|
||||||
|
shapes.push_back(o.shape());
|
||||||
|
types.push_back(o.dtype());
|
||||||
|
}
|
||||||
|
auto real_out =
|
||||||
|
array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
|
||||||
|
for (int i = 0; i < trace_out.size(); ++i) {
|
||||||
|
trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])});
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -412,11 +429,11 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|||||||
return detail::compile(fun, fun_id);
|
return detail::compile(fun, fun_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
void disable_compiler() {
|
void disable_compile() {
|
||||||
detail::compiler_disabled() = true;
|
detail::compiler_disabled() = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void enable_compiler() {
|
void enable_compile() {
|
||||||
detail::compiler_disabled() = false;
|
detail::compiler_disabled() = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,15 +11,15 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
|
||||||
|
|
||||||
/** Globally disable compilation.
|
/** Globally disable compilation.
|
||||||
* Setting the environment variable ``MLX_DISABLE_COMPILER`` can also
|
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
|
||||||
* be used to disable compilation.
|
* be used to disable compilation.
|
||||||
*/
|
*/
|
||||||
void disable_compiler();
|
void disable_compile();
|
||||||
|
|
||||||
/** Globally enable compilation.
|
/** Globally enable compilation.
|
||||||
* This will override the environment variable ``MLX_DISABLE_COMPILER``.
|
* This will override the environment variable ``MLX_DISABLE_COMPILE``.
|
||||||
*/
|
*/
|
||||||
void enable_compiler();
|
void enable_compile();
|
||||||
|
|
||||||
void eval(const std::vector<array>& outputs);
|
void eval(const std::vector<array>& outputs);
|
||||||
|
|
||||||
|
@ -816,22 +816,22 @@ void init_transforms(py::module_& m) {
|
|||||||
as ``fun`` and returns the the same output(s).
|
as ``fun`` and returns the the same output(s).
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"disable_compiler",
|
"disable_compile",
|
||||||
&disable_compiler,
|
&disable_compile,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
disable_compiler() -> None
|
disable_compile() -> None
|
||||||
|
|
||||||
Globally disable compilation. Setting the environment variable
|
Globally disable compilation. Setting the environment variable
|
||||||
``MLX_DISABLE_COMPILER`` can also be used to disable compilation.
|
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"enable_compiler",
|
"enable_compile",
|
||||||
&enable_compiler,
|
&enable_compile,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
enable_compiler() -> None
|
enable_compiler() -> None
|
||||||
|
|
||||||
Globally enable compilation. This will override the environment
|
Globally enable compilation. This will override the environment
|
||||||
variable ``MLX_DISABLE_COMPILER`` if set.
|
variable ``MLX_DISABLE_COMPILE`` if set.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
|
||||||
// Register static Python object cleanup before the interpreter exits
|
// Register static Python object cleanup before the interpreter exits
|
||||||
|
@ -164,12 +164,12 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
n_compiled = count_prims(cfun(x))
|
n_compiled = count_prims(cfun(x))
|
||||||
|
|
||||||
# Check disabled
|
# Check disabled
|
||||||
mx.disable_compiler()
|
mx.disable_compile()
|
||||||
n_uncompiled = count_prims(cfun(x))
|
n_uncompiled = count_prims(cfun(x))
|
||||||
self.assertTrue(n_compiled < n_uncompiled)
|
self.assertTrue(n_compiled < n_uncompiled)
|
||||||
|
|
||||||
# Check renabled
|
# Check renabled
|
||||||
mx.enable_compiler()
|
mx.enable_compile()
|
||||||
n_enable_compiled = count_prims(cfun(x))
|
n_enable_compiled = count_prims(cfun(x))
|
||||||
self.assertEqual(n_compiled, n_enable_compiled)
|
self.assertEqual(n_compiled, n_enable_compiled)
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
#include "doctest/doctest.h"
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
@ -99,9 +100,9 @@ TEST_CASE("test nested compile") {
|
|||||||
|
|
||||||
TEST_CASE("test enable and disable compile") {
|
TEST_CASE("test enable and disable compile") {
|
||||||
CHECK_THROWS(compile(nullptr));
|
CHECK_THROWS(compile(nullptr));
|
||||||
disable_compiler();
|
disable_compile();
|
||||||
compile(nullptr);
|
compile(nullptr);
|
||||||
enable_compiler();
|
enable_compile();
|
||||||
CHECK_THROWS(compile(nullptr));
|
CHECK_THROWS(compile(nullptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user