mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
fix multi-output with compile
This commit is contained in:
parent
ed4d867092
commit
390e50b163
@ -15,7 +15,7 @@ namespace detail {
|
||||
|
||||
bool& compiler_disabled() {
|
||||
auto get_val = []() {
|
||||
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILER")) {
|
||||
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
@ -343,13 +343,30 @@ std::vector<array> compile_replace(
|
||||
if (!a.has_primitive()) {
|
||||
trace_to_real.insert({a.id(), a});
|
||||
} else {
|
||||
// Find real inputs
|
||||
std::vector<array> real_inputs;
|
||||
for (auto& in : a.inputs()) {
|
||||
real_inputs.push_back(trace_to_real.at(in.id()));
|
||||
}
|
||||
auto real_a = array(
|
||||
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
|
||||
trace_to_real.insert({a.id(), std::move(real_a)});
|
||||
if (a.siblings().empty()) {
|
||||
auto real_a = array(
|
||||
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
|
||||
trace_to_real.insert({a.id(), std::move(real_a)});
|
||||
} else {
|
||||
// Ensure the order is correct for multi-output primitives
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<Dtype> types;
|
||||
auto trace_out = a.outputs();
|
||||
for (auto& o : trace_out) {
|
||||
shapes.push_back(o.shape());
|
||||
types.push_back(o.dtype());
|
||||
}
|
||||
auto real_out =
|
||||
array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
|
||||
for (int i = 0; i < trace_out.size(); ++i) {
|
||||
trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -412,11 +429,11 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
return detail::compile(fun, fun_id);
|
||||
}
|
||||
|
||||
void disable_compiler() {
|
||||
void disable_compile() {
|
||||
detail::compiler_disabled() = true;
|
||||
}
|
||||
|
||||
void enable_compiler() {
|
||||
void enable_compile() {
|
||||
detail::compiler_disabled() = false;
|
||||
}
|
||||
|
||||
|
@ -11,15 +11,15 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
|
||||
|
||||
/** Globally disable compilation.
|
||||
* Setting the environment variable ``MLX_DISABLE_COMPILER`` can also
|
||||
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
|
||||
* be used to disable compilation.
|
||||
*/
|
||||
void disable_compiler();
|
||||
void disable_compile();
|
||||
|
||||
/** Globally enable compilation.
|
||||
* This will override the environment variable ``MLX_DISABLE_COMPILER``.
|
||||
* This will override the environment variable ``MLX_DISABLE_COMPILE``.
|
||||
*/
|
||||
void enable_compiler();
|
||||
void enable_compile();
|
||||
|
||||
void eval(const std::vector<array>& outputs);
|
||||
|
||||
|
@ -816,22 +816,22 @@ void init_transforms(py::module_& m) {
|
||||
as ``fun`` and returns the the same output(s).
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"disable_compiler",
|
||||
&disable_compiler,
|
||||
"disable_compile",
|
||||
&disable_compile,
|
||||
R"pbdoc(
|
||||
disable_compiler() -> None
|
||||
disable_compile() -> None
|
||||
|
||||
Globally disable compilation. Setting the environment variable
|
||||
``MLX_DISABLE_COMPILER`` can also be used to disable compilation.
|
||||
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"enable_compiler",
|
||||
&enable_compiler,
|
||||
"enable_compile",
|
||||
&enable_compile,
|
||||
R"pbdoc(
|
||||
enable_compiler() -> None
|
||||
|
||||
Globally enable compilation. This will override the environment
|
||||
variable ``MLX_DISABLE_COMPILER`` if set.
|
||||
variable ``MLX_DISABLE_COMPILE`` if set.
|
||||
)pbdoc");
|
||||
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
|
@ -164,12 +164,12 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
n_compiled = count_prims(cfun(x))
|
||||
|
||||
# Check disabled
|
||||
mx.disable_compiler()
|
||||
mx.disable_compile()
|
||||
n_uncompiled = count_prims(cfun(x))
|
||||
self.assertTrue(n_compiled < n_uncompiled)
|
||||
|
||||
# Check renabled
|
||||
mx.enable_compiler()
|
||||
mx.enable_compile()
|
||||
n_enable_compiled = count_prims(cfun(x))
|
||||
self.assertEqual(n_compiled, n_enable_compiled)
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <iostream>
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
@ -99,9 +100,9 @@ TEST_CASE("test nested compile") {
|
||||
|
||||
TEST_CASE("test enable and disable compile") {
|
||||
CHECK_THROWS(compile(nullptr));
|
||||
disable_compiler();
|
||||
disable_compile();
|
||||
compile(nullptr);
|
||||
enable_compiler();
|
||||
enable_compile();
|
||||
CHECK_THROWS(compile(nullptr));
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user