mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compile primitive (#571)
* Compiled primitive with basic binary, unary graph-level fusion
This commit is contained in:
468
mlx/compile.cpp
468
mlx/compile.cpp
@@ -1,36 +1,198 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/compile.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace detail {
|
||||
constexpr int max_compile_depth = 6;
|
||||
|
||||
bool& compiler_disabled() {
|
||||
auto get_val = []() {
|
||||
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
static bool compiler_disabled_ = get_val();
|
||||
return compiler_disabled_;
|
||||
bool is_unary(const Primitive& p) {
|
||||
return (
|
||||
typeid(p) == typeid(Abs) || typeid(p) == typeid(ArcCos) ||
|
||||
typeid(p) == typeid(ArcCosh) || typeid(p) == typeid(ArcSin) ||
|
||||
typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) ||
|
||||
typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) ||
|
||||
typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) ||
|
||||
typeid(p) == typeid(Cosh) || typeid(p) == typeid(Remainder) ||
|
||||
typeid(p) == typeid(Erf) || typeid(p) == typeid(ErfInv) ||
|
||||
typeid(p) == typeid(Exp) || typeid(p) == typeid(Floor) ||
|
||||
typeid(p) == typeid(Log) || typeid(p) == typeid(Log1p) ||
|
||||
typeid(p) == typeid(LogicalNot) || typeid(p) == typeid(Negative) ||
|
||||
typeid(p) == typeid(Round) || typeid(p) == typeid(Sigmoid) ||
|
||||
typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||
|
||||
typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||
|
||||
typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||
|
||||
typeid(p) == typeid(Tanh));
|
||||
}
|
||||
|
||||
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
|
||||
bool is_binary(const Primitive& p) {
|
||||
return (
|
||||
typeid(p) == typeid(Add) || typeid(p) == typeid(Divide) ||
|
||||
typeid(p) == typeid(Equal) || typeid(p) == typeid(Greater) ||
|
||||
typeid(p) == typeid(GreaterEqual) || typeid(p) == typeid(Less) ||
|
||||
typeid(p) == typeid(LessEqual) || typeid(p) == typeid(LogicalNot) ||
|
||||
typeid(p) == typeid(LogicalAnd) || typeid(p) == typeid(LogicalOr) ||
|
||||
typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
|
||||
typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
|
||||
typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
|
||||
typeid(p) == typeid(Subtract));
|
||||
}
|
||||
|
||||
bool is_broadcast(const Primitive& p) {
|
||||
return typeid(p) == typeid(Broadcast);
|
||||
}
|
||||
|
||||
bool is_noop(const Primitive& p) {
|
||||
return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient);
|
||||
}
|
||||
|
||||
bool is_fusable(const Primitive& p) {
|
||||
return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
std::vector<array> compile_replace(
|
||||
const std::vector<array>& tape,
|
||||
const std::vector<array>& trace_inputs,
|
||||
const std::vector<array>& trace_outputs,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Builds a Compiled primitive on `stream`.
//
// All containers are taken by value and moved into the members:
//  - inputs / outputs: the traced arrays at the boundary of the fused section
//  - tape:             the fused arrays, in the order they were recorded
//  - constant_ids:     ids of the inputs recorded as constants by the caller
Compiled::Compiled(
    Stream stream,
    std::vector<array> inputs,
    std::vector<array> outputs,
    std::vector<array> tape,
    std::unordered_set<uintptr_t> constant_ids)
    : Primitive(stream),
      inputs_(std::move(inputs)),
      outputs_(std::move(outputs)),
      tape_(std::move(tape)),
      constant_ids_(std::move(constant_ids)) {}
|
||||
|
||||
std::vector<array> Compiled::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
auto fun = [this](const std::vector<array>& inputs) {
|
||||
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
|
||||
};
|
||||
|
||||
auto [_, vjps] = mlx::core::vjp(fun, primals, cotangents);
|
||||
std::vector<array> vjp_outs;
|
||||
for (int i = 0, j = 0; i < vjps.size(); ++i) {
|
||||
if (i < argnums.size() && i == argnums[j]) {
|
||||
vjp_outs.push_back(vjps[i]);
|
||||
j++;
|
||||
}
|
||||
}
|
||||
return vjp_outs;
|
||||
}
|
||||
|
||||
std::vector<array> Compiled::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
auto fun = [this](const std::vector<array>& inputs) {
|
||||
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
|
||||
};
|
||||
|
||||
auto [_, jvps] = mlx::core::jvp(fun, primals, tangents);
|
||||
std::vector<array> jvp_outs;
|
||||
for (int i = 0, j = 0; i < jvps.size(); ++i) {
|
||||
if (i < argnums.size() && i == argnums[j]) {
|
||||
jvp_outs.push_back(jvps[i]);
|
||||
j++;
|
||||
}
|
||||
}
|
||||
return jvp_outs;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Compiled::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto fun = [this](const std::vector<array>& inputs) {
|
||||
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
|
||||
};
|
||||
|
||||
auto outputs = mlx::core::vmap(fun, axes)(inputs);
|
||||
auto out_axes = std::vector<int>(outputs.size(), 0);
|
||||
return {outputs, out_axes};
|
||||
}
|
||||
|
||||
bool Compiled::is_equivalent(const Primitive& other) const {
|
||||
const Compiled& a_other = static_cast<const Compiled&>(other);
|
||||
return std::equal(
|
||||
tape_.begin(),
|
||||
tape_.end(),
|
||||
a_other.tape_.begin(),
|
||||
a_other.tape_.end(),
|
||||
[](const array& a1, const array& a2) {
|
||||
auto& p1 = a1.primitive();
|
||||
auto& p2 = a2.primitive();
|
||||
return typeid(p1) == typeid(p2) && p1.is_equivalent(p2);
|
||||
});
|
||||
}
|
||||
|
||||
void Compiled::print(std::ostream& os) {
|
||||
os << "Compiled";
|
||||
for (auto& a : tape_) {
|
||||
a.primitive().print(os);
|
||||
}
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Accessor for the process-wide compile mode.
//
// The value is initialised on first use: CompileMode::disabled when the
// MLX_DISABLE_COMPILE environment variable is set (to any value),
// CompileMode::enabled otherwise. The returned reference lets callers change
// the mode afterwards.
CompileMode& compile_mode() {
  static CompileMode compile_mode_ =
      (std::getenv("MLX_DISABLE_COMPILE") != nullptr) ? CompileMode::disabled
                                                      : CompileMode::enabled;
  return compile_mode_;
}
|
||||
|
||||
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
|
||||
using ParentsMap =
|
||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
||||
|
||||
// Helper that merges two arrays in the graph by setting the parents of the
|
||||
// source to point to the destination
|
||||
void merge(array& dst, array& src, ParentsMap& parents_map) {
|
||||
// Canonicalize the order of the primitives outputs
|
||||
auto sources = src.outputs();
|
||||
auto dests = dst.outputs();
|
||||
// For each src parent, point it to the corresponding dst
|
||||
for (int i = 0; i < sources.size(); ++i) {
|
||||
auto src_parents = parents_map.find(sources[i].id());
|
||||
if (src_parents == parents_map.end()) {
|
||||
continue;
|
||||
}
|
||||
auto& pairs = parents_map[dests[i].id()];
|
||||
for (auto& parent : src_parents->second) {
|
||||
parent.first.inputs()[parent.second] = dests[i];
|
||||
pairs.push_back(parent);
|
||||
}
|
||||
// Remove the source from the map to avoid fusing with it again
|
||||
parents_map.erase(src_parents);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename... U>
|
||||
size_t getAddress(std::function<T(U...)> f) {
|
||||
typedef T(fnType)(U...);
|
||||
@@ -59,9 +221,10 @@ struct CompilerCache {
|
||||
auto is_match = [](const std::vector<array>& in1,
|
||||
const std::vector<array>& in2) {
|
||||
if (in1.size() != in2.size()) {
|
||||
throw std::runtime_error(
|
||||
"[compiler] Got different number of inputs to function,"
|
||||
" this should never happen.");
|
||||
std::ostringstream msg;
|
||||
msg << "[compiler] Unexpected number of inputs to compiled function:"
|
||||
<< " expected " << in2.size() << " got " << in1.size() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
for (int i = 0; i < in1.size(); ++i) {
|
||||
if (in1[i].shape() != in2[i].shape()) {
|
||||
@@ -205,28 +368,6 @@ void compile_simplify(
|
||||
}
|
||||
}
|
||||
|
||||
// Helper that fuses two arrays in the graph by setting the parents of the
|
||||
// source to point to the destination
|
||||
auto fuse = [&](array& dst, array& src) {
|
||||
// Canonicalize the order of the primitives outputs
|
||||
auto sources = src.outputs();
|
||||
auto dests = dst.outputs();
|
||||
// For each src parent, point it to the corresponding dest
|
||||
for (int i = 0; i < sources.size(); ++i) {
|
||||
auto src_parents = parents_map.find(sources[i].id());
|
||||
if (src_parents == parents_map.end()) {
|
||||
continue;
|
||||
}
|
||||
auto& pairs = parents_map[dests[i].id()];
|
||||
for (auto& parent : src_parents->second) {
|
||||
parent.first.inputs()[parent.second] = dests[i];
|
||||
pairs.push_back(parent);
|
||||
}
|
||||
// Remove the source from the map to avoid fusing with it again
|
||||
parents_map.erase(src_parents);
|
||||
}
|
||||
};
|
||||
|
||||
// Depth-1 array equivalence check.
|
||||
auto array_equivalent = [](const array& a, const array& b) {
|
||||
if (!a.has_primitive() || !b.has_primitive()) {
|
||||
@@ -254,33 +395,32 @@ void compile_simplify(
|
||||
return pa.is_equivalent(pb);
|
||||
};
|
||||
|
||||
// Pass 0: fuse scalars
|
||||
// Merge scalars
|
||||
std::vector<array> new_tape;
|
||||
for (auto& arr : tape) {
|
||||
// Check if we can fuse scalars
|
||||
// Check if we can merge scalars
|
||||
if (is_scalar(arr)) {
|
||||
auto scalar = scalars.find(get_scalar_rep(arr));
|
||||
if (scalar->second.id() != arr.id()) {
|
||||
fuse(scalar->second, arr);
|
||||
merge(scalar->second, arr, parents_map);
|
||||
// Don't keep orphaned scalars in the tape
|
||||
continue;
|
||||
}
|
||||
}
|
||||
new_tape.push_back(std::move(arr));
|
||||
}
|
||||
|
||||
tape = std::move(new_tape);
|
||||
|
||||
std::unordered_set<uintptr_t> output_set;
|
||||
for (auto& o : outputs) {
|
||||
output_set.insert(o.id());
|
||||
}
|
||||
// Pass 1..passes: fuse only keeping non-orphaned arrays in the tape
|
||||
// Multi-pass merge only keeping non-orphaned arrays in the tape
|
||||
for (int pass = 0; pass < passes; ++pass) {
|
||||
for (auto& arr : tape) {
|
||||
// Helper to check if we can fuse the parents of the
|
||||
// Helper to check if we can merge the parents of the
|
||||
// given array
|
||||
auto maybe_fuse_parents = [&](auto& a) {
|
||||
auto maybe_merge_parents = [&](auto& a) {
|
||||
auto parents = parents_map.find(a.id());
|
||||
if (parents != parents_map.end()) {
|
||||
auto N = parents->second.size();
|
||||
@@ -296,7 +436,7 @@ void compile_simplify(
|
||||
auto& src = parents->second[j].first;
|
||||
auto& dst = parents->second[i].first;
|
||||
if (src.id() != dst.id() && array_equivalent(src, dst)) {
|
||||
fuse(dst, src);
|
||||
merge(dst, src, parents_map);
|
||||
mask[j] = true;
|
||||
}
|
||||
}
|
||||
@@ -313,9 +453,9 @@ void compile_simplify(
|
||||
}
|
||||
};
|
||||
|
||||
bool discard = maybe_fuse_parents(arr);
|
||||
bool discard = maybe_merge_parents(arr);
|
||||
for (auto& s : arr.siblings()) {
|
||||
discard &= maybe_fuse_parents(s);
|
||||
discard &= maybe_merge_parents(s);
|
||||
}
|
||||
// If an array and its siblings have no parents, and none of them are
|
||||
// outputs, it is safe to remove it from the tape
|
||||
@@ -327,6 +467,216 @@ void compile_simplify(
|
||||
}
|
||||
}
|
||||
|
||||
// Extract sub-graphs of the graph that can be compiled
|
||||
// and replace them with a Compiled Primitive.
|
||||
void compile_fuse(
|
||||
std::vector<array>& tape,
|
||||
ParentsMap& parents_map,
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Track outputs to replace with new compiled outputs
|
||||
std::unordered_map<uintptr_t, array> output_map;
|
||||
for (auto& o : outputs) {
|
||||
output_map.insert({o.id(), o});
|
||||
}
|
||||
|
||||
// Set of inputs to distinguish constants
|
||||
std::unordered_set<uintptr_t> input_ids;
|
||||
for (auto& in : inputs) {
|
||||
input_ids.insert(in.id());
|
||||
}
|
||||
|
||||
// Go through the tape in reverse order and check for fusable sub-graphs
|
||||
std::vector<array> new_tape;
|
||||
std::unordered_set<uintptr_t> global_cache;
|
||||
for (int i = tape.size() - 1; i >= 0; --i) {
|
||||
auto& arr = tape[i];
|
||||
|
||||
// Already compiled
|
||||
if (global_cache.find(arr.id()) != global_cache.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Two pass recursion:
|
||||
// First pass:
|
||||
// - Collect all the primitives which we can fuse with
|
||||
// - Keeps a cache of fusable primitives which may be added out of
|
||||
// DAG order. We have to determine if all of a fused primitive's
|
||||
// outputs are also in the fused section, and this may not be the
|
||||
// case the first time we visit it.
|
||||
// Second pass:
|
||||
// - Collect inputs to the new compiled primitive
|
||||
// - Add fusable primitives to a tape in the correct order
|
||||
|
||||
std::function<void(const array&, int, const Stream&)> recurse;
|
||||
std::unordered_set<uintptr_t> cache;
|
||||
recurse = [&](const array& a, int depth, const Stream& s) {
|
||||
if (cache.find(a.id()) != cache.end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Stop fusing if:
|
||||
// - Depth limit exceeded
|
||||
// - Constant input
|
||||
// - Stream mismatch
|
||||
// - Non fusable primitive
|
||||
if (depth >= max_compile_depth || !a.has_primitive() ||
|
||||
a.primitive().stream() != s || !is_fusable(a.primitive())) {
|
||||
return;
|
||||
}
|
||||
|
||||
bool all_parents_in = true;
|
||||
if (depth > 0) {
|
||||
// Guaranteed to have a parent since nested in the
|
||||
// recursion.
|
||||
auto& parents = parents_map.at(a.id());
|
||||
for (auto& [p, idx] : parents) {
|
||||
auto in_cache = cache.find(p.id()) != cache.end();
|
||||
if (!in_cache) {
|
||||
all_parents_in = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Arrays with a mix of parents outside the compilable section
|
||||
// are not fusable
|
||||
if (!all_parents_in) {
|
||||
return;
|
||||
}
|
||||
|
||||
cache.insert({a.id()});
|
||||
|
||||
for (auto& in : a.inputs()) {
|
||||
recurse(in, depth + 1, s);
|
||||
}
|
||||
};
|
||||
|
||||
if (arr.has_primitive()) {
|
||||
Stream s = arr.primitive().stream();
|
||||
recurse(arr, 0, s);
|
||||
}
|
||||
|
||||
// Not worth fusing a single primitive
|
||||
if (cache.size() <= 1) {
|
||||
new_tape.push_back(arr);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Recurse a second time to build the tape in the right
|
||||
// order and collect the inputs
|
||||
std::unordered_set<uintptr_t> input_set;
|
||||
std::vector<array> inputs;
|
||||
std::vector<array> fused_tape;
|
||||
std::unordered_set<uintptr_t> tape_set;
|
||||
std::function<void(const array&)> recurse_tape;
|
||||
recurse_tape = [&](const array& a) {
|
||||
if (cache.find(a.id()) == cache.end()) {
|
||||
if (input_set.find(a.id()) == input_set.end()) {
|
||||
input_set.insert(a.id());
|
||||
inputs.push_back(a);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (tape_set.find(a.id()) != tape_set.end()) {
|
||||
return;
|
||||
}
|
||||
tape_set.insert(a.id());
|
||||
for (auto& in : a.inputs()) {
|
||||
recurse_tape(in);
|
||||
}
|
||||
fused_tape.push_back(a);
|
||||
};
|
||||
recurse_tape(arr);
|
||||
|
||||
std::vector<array> old_outputs;
|
||||
// Add to global cache and add any global outputs to outputs
|
||||
// of new primitive
|
||||
for (int j = 0; j < fused_tape.size() - 1; ++j) {
|
||||
auto& f = fused_tape[j];
|
||||
if (output_map.find(f.id()) != output_map.end()) {
|
||||
old_outputs.push_back(f);
|
||||
// Parents are now siblings, update the parent map
|
||||
auto& pairs = parents_map[f.id()];
|
||||
pairs.erase(
|
||||
std::remove_if(
|
||||
pairs.begin(),
|
||||
pairs.end(),
|
||||
[&](auto& p) {
|
||||
return cache.find(p.first.id()) != cache.end();
|
||||
}),
|
||||
pairs.end());
|
||||
} else {
|
||||
// Remove inner fused arrays parents from the parents map
|
||||
// to keep the parents map in a valid state
|
||||
parents_map.erase(f.id());
|
||||
}
|
||||
global_cache.insert({f.id()});
|
||||
}
|
||||
old_outputs.push_back(arr);
|
||||
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<Dtype> types;
|
||||
for (auto& o : old_outputs) {
|
||||
shapes.push_back(o.shape());
|
||||
types.push_back(o.dtype());
|
||||
}
|
||||
std::unordered_set<uintptr_t> constant_ids;
|
||||
for (auto& in : inputs) {
|
||||
// Scalar constant
|
||||
if (in.size() == 1 && !in.has_primitive() &&
|
||||
input_ids.find(in.id()) == input_ids.end()) {
|
||||
constant_ids.insert(in.id());
|
||||
}
|
||||
}
|
||||
auto compiled_outputs = array::make_arrays(
|
||||
shapes,
|
||||
types,
|
||||
std::make_shared<Compiled>(
|
||||
outputs.back().primitive().stream(),
|
||||
inputs,
|
||||
old_outputs,
|
||||
std::move(fused_tape),
|
||||
std::move(constant_ids)),
|
||||
inputs);
|
||||
|
||||
// One output per primitive
|
||||
new_tape.push_back(compiled_outputs.back());
|
||||
|
||||
// Replace inputs old parents with compiled_outputs
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
auto& pairs = parents_map[inputs[i].id()];
|
||||
pairs.erase(
|
||||
std::remove_if(
|
||||
pairs.begin(),
|
||||
pairs.end(),
|
||||
[&](auto& p) { return cache.find(p.first.id()) != cache.end(); }),
|
||||
pairs.end());
|
||||
for (auto& o : compiled_outputs) {
|
||||
pairs.push_back({o, i});
|
||||
}
|
||||
}
|
||||
|
||||
// - Update outputs parents to point to compiled outputs
|
||||
// - Update any overall graph outputs to be compiled outputs
|
||||
for (int o = 0; o < old_outputs.size(); ++o) {
|
||||
merge(compiled_outputs[o], old_outputs[o], parents_map);
|
||||
if (auto it = output_map.find(old_outputs[o].id());
|
||||
it != output_map.end()) {
|
||||
it->second = compiled_outputs[o];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::reverse(new_tape.begin(), new_tape.end());
|
||||
tape = std::move(new_tape);
|
||||
|
||||
// Replace output with potentially compiled output
|
||||
for (auto& o : outputs) {
|
||||
o = output_map.at(o.id());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> compile_replace(
|
||||
const std::vector<array>& tape,
|
||||
const std::vector<array>& trace_inputs,
|
||||
@@ -380,7 +730,7 @@ std::vector<array> compile_replace(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
size_t fun_id) {
|
||||
if (compiler_disabled()) {
|
||||
if (compile_mode() == CompileMode::disabled) {
|
||||
return fun;
|
||||
}
|
||||
return [fun, fun_id](const std::vector<array>& inputs) {
|
||||
@@ -402,10 +752,16 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
compile_dfs(entry.inputs, entry.outputs);
|
||||
|
||||
// Simplify the tape
|
||||
compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 3);
|
||||
if (compile_mode() != CompileMode::no_simplify) {
|
||||
compile_simplify(
|
||||
entry.tape, parents_map, entry.outputs, /* passes */ 3);
|
||||
}
|
||||
|
||||
// This is a good point to do more optimizations, e.g. kernel fusion to
|
||||
// generate new primitives. The tape needs to be updated accordingly
|
||||
// Kernel fusion to generate Compiled primitives. The tape and
|
||||
// new outputs must be updated accordingly
|
||||
if (compile_mode() != CompileMode::no_fuse) {
|
||||
compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
|
||||
}
|
||||
}
|
||||
|
||||
// At this point we must have a tape, now replace the placeholders
|
||||
@@ -422,7 +778,7 @@ void compile_erase(size_t fun_id) {
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
|
||||
if (detail::compiler_disabled()) {
|
||||
if (detail::compile_mode() == CompileMode::disabled) {
|
||||
return fun;
|
||||
}
|
||||
auto fun_id = detail::getAddress(fun);
|
||||
@@ -430,11 +786,15 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
}
|
||||
|
||||
// Turns compilation off by setting the global compile mode to
// CompileMode::disabled. The compile mode is the single piece of state that
// set_compile_mode() writes, so the separate legacy boolean flag is no longer
// written here.
void disable_compile() {
  detail::compile_mode() = CompileMode::disabled;
}
|
||||
|
||||
// Turns compilation on by setting the global compile mode to
// CompileMode::enabled. The compile mode is the single piece of state that
// set_compile_mode() writes, so the separate legacy boolean flag is no longer
// written here.
void enable_compile() {
  detail::compile_mode() = CompileMode::enabled;
}
|
||||
|
||||
// Overwrites the global compile mode with `mode`.
void set_compile_mode(CompileMode mode) {
  CompileMode& current = detail::compile_mode();
  current = mode;
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user