Compile primitive (#571)

* Compiled primitive with basic binary, unary graph-level fusion
This commit is contained in:
Awni Hannun 2024-02-05 06:51:22 -08:00 committed by GitHub
parent 31fea3758e
commit d75ae52ecd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 1088 additions and 75 deletions

View File

@ -33,10 +33,12 @@ DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT_MULTI(Compiled)
DEFAULT(Concatenate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
@ -57,6 +59,7 @@ DEFAULT(Minimum)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT_MULTI(QRF)
DEFAULT(RandomBits)
DEFAULT(Reshape)
DEFAULT(Round)
@ -68,8 +71,6 @@ DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(StopGradient)
DEFAULT(Transpose)
DEFAULT_MULTI(DivMod)
DEFAULT_MULTI(QRF)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);

View File

@ -3,6 +3,7 @@ target_sources(
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp

View File

@ -0,0 +1,59 @@
// Copyright © 2023-2024 Apple Inc.
#include <queue>
#include "mlx/primitives.h"
namespace mlx::core {
// Build the real tape
std::pair<std::queue<array>, std::vector<array>> trace_to_real(
const std::vector<array>& trace_tape,
const std::vector<array>& trace_inputs,
const std::vector<array>& trace_outputs,
const std::vector<array>& inputs) {
std::unordered_map<uintptr_t, array> trace_to_real;
for (int i = 0; i < inputs.size(); ++i) {
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
}
std::queue<array> tape;
for (auto& a : trace_tape) {
// Find real inputs
std::vector<array> real_inputs;
for (auto& in : a.inputs()) {
real_inputs.push_back(trace_to_real.at(in.id()));
}
tape.push(
array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)));
trace_to_real.insert({a.id(), tape.back()});
}
std::vector<array> outputs;
for (auto& o : trace_outputs) {
outputs.push_back(trace_to_real.at(o.id()));
}
return {tape, outputs};
}
void Compiled::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Make the a real tape from the tracers
auto [tape, real_outputs] = trace_to_real(tape_, inputs_, outputs_, inputs);
// Run the tape
while (!tape.empty()) {
auto a = std::move(tape.front());
tape.pop();
auto outputs = a.outputs();
a.primitive().eval_cpu(a.inputs(), outputs);
a.detach();
}
// Copy results into outputs
for (int o = 0; o < real_outputs.size(); ++o) {
outputs[o].copy_shared_buffer(real_outputs[o]);
}
}
} // namespace mlx::core

View File

@ -41,7 +41,9 @@ DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT_MULTI(Compiled)
DEFAULT(Concatenate)
DEFAULT(Convolution)
DEFAULT(Copy)
@ -78,6 +80,7 @@ DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT(Power)
DEFAULT_MULTI(QRF)
DEFAULT(QuantizedMatmul)
DEFAULT(RandomBits)
DEFAULT(Reduce)
@ -100,8 +103,6 @@ DEFAULT(Subtract)
DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT_MULTI(DivMod)
DEFAULT_MULTI(QRF)
namespace {

View File

@ -2,6 +2,7 @@ target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp

View File

@ -0,0 +1,44 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/device.h"
#include "mlx/primitives.h"
namespace mlx::core {
void Compiled::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Just a fall-back to the original tape for now
std::unordered_map<uintptr_t, array> trace_to_real;
for (int i = 0; i < inputs.size(); ++i) {
trace_to_real.insert({inputs_[i].id(), inputs[i]});
}
for (int i = 0; i < outputs.size(); ++i) {
trace_to_real.insert({outputs_[i].id(), outputs[i]});
}
for (auto& a : tape_) {
std::vector<array> p_inputs;
for (auto& in : a.inputs()) {
p_inputs.push_back(trace_to_real.at(in.id()));
}
// If a is an output get it from the map, otherwise create it
// NB this is safe as long as no multi-output sub primitves are allowed
// in Compiled
std::vector<array> p_outputs;
if (auto it = trace_to_real.find(a.id()); it != trace_to_real.end()) {
p_outputs.push_back(it->second);
} else {
p_outputs.push_back(array(a.shape(), a.dtype(), a.primitive_ptr(), {}));
trace_to_real.insert({a.id(), p_outputs[0]});
}
a.primitive().eval_gpu(p_inputs, p_outputs);
}
auto& s = stream();
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler(
[trace_to_real](MTL::CommandBuffer*) mutable {});
}
} // namespace mlx::core

View File

@ -32,6 +32,7 @@ NO_GPU(AsType)
NO_GPU(AsStrided)
NO_GPU(Broadcast)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
NO_GPU(Concatenate)
NO_GPU(Convolution)
NO_GPU(Copy)
@ -40,6 +41,7 @@ NO_GPU(Cosh)
NO_GPU_MULTI(CustomVJP)
NO_GPU_MULTI(Depends)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)
@ -69,6 +71,7 @@ NO_GPU(NotEqual)
NO_GPU(Pad)
NO_GPU(Partition)
NO_GPU(Power)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(RandomBits)
NO_GPU(Reduce)
@ -91,6 +94,5 @@ NO_GPU(Subtract)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Transpose)
NO_GPU_MULTI(DivMod)
NO_GPU_MULTI(QRF)
} // namespace mlx::core

View File

@ -1,36 +1,198 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include "mlx/allocator.h"
#include "mlx/compile.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
namespace mlx::core {
namespace detail {
constexpr int max_compile_depth = 6;
bool& compiler_disabled() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
return true;
} else {
return false;
}
};
static bool compiler_disabled_ = get_val();
return compiler_disabled_;
bool is_unary(const Primitive& p) {
return (
typeid(p) == typeid(Abs) || typeid(p) == typeid(ArcCos) ||
typeid(p) == typeid(ArcCosh) || typeid(p) == typeid(ArcSin) ||
typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) ||
typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) ||
typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) ||
typeid(p) == typeid(Cosh) || typeid(p) == typeid(Remainder) ||
typeid(p) == typeid(Erf) || typeid(p) == typeid(ErfInv) ||
typeid(p) == typeid(Exp) || typeid(p) == typeid(Floor) ||
typeid(p) == typeid(Log) || typeid(p) == typeid(Log1p) ||
typeid(p) == typeid(LogicalNot) || typeid(p) == typeid(Negative) ||
typeid(p) == typeid(Round) || typeid(p) == typeid(Sigmoid) ||
typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||
typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||
typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||
typeid(p) == typeid(Tanh));
}
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
bool is_binary(const Primitive& p) {
return (
typeid(p) == typeid(Add) || typeid(p) == typeid(Divide) ||
typeid(p) == typeid(Equal) || typeid(p) == typeid(Greater) ||
typeid(p) == typeid(GreaterEqual) || typeid(p) == typeid(Less) ||
typeid(p) == typeid(LessEqual) || typeid(p) == typeid(LogicalNot) ||
typeid(p) == typeid(LogicalAnd) || typeid(p) == typeid(LogicalOr) ||
typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
typeid(p) == typeid(Subtract));
}
bool is_broadcast(const Primitive& p) {
return typeid(p) == typeid(Broadcast);
}
bool is_noop(const Primitive& p) {
return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient);
}
bool is_fusable(const Primitive& p) {
return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
}
namespace detail {
std::vector<array> compile_replace(
const std::vector<array>& tape,
const std::vector<array>& trace_inputs,
const std::vector<array>& trace_outputs,
const std::vector<array>& inputs);
} // namespace detail
Compiled::Compiled(
Stream stream,
std::vector<array> inputs,
std::vector<array> outputs,
std::vector<array> tape,
std::unordered_set<uintptr_t> constant_ids)
: Primitive(stream),
inputs_(std::move(inputs)),
outputs_(std::move(outputs)),
tape_(std::move(tape)),
constant_ids_(std::move(constant_ids)) {}
std::vector<array> Compiled::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
auto fun = [this](const std::vector<array>& inputs) {
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
};
auto [_, vjps] = mlx::core::vjp(fun, primals, cotangents);
std::vector<array> vjp_outs;
for (int i = 0, j = 0; i < vjps.size(); ++i) {
if (i < argnums.size() && i == argnums[j]) {
vjp_outs.push_back(vjps[i]);
j++;
}
}
return vjp_outs;
}
std::vector<array> Compiled::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
auto fun = [this](const std::vector<array>& inputs) {
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
};
auto [_, jvps] = mlx::core::jvp(fun, primals, tangents);
std::vector<array> jvp_outs;
for (int i = 0, j = 0; i < jvps.size(); ++i) {
if (i < argnums.size() && i == argnums[j]) {
jvp_outs.push_back(jvps[i]);
j++;
}
}
return jvp_outs;
}
std::pair<std::vector<array>, std::vector<int>> Compiled::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto fun = [this](const std::vector<array>& inputs) {
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
};
auto outputs = mlx::core::vmap(fun, axes)(inputs);
auto out_axes = std::vector<int>(outputs.size(), 0);
return {outputs, out_axes};
}
bool Compiled::is_equivalent(const Primitive& other) const {
const Compiled& a_other = static_cast<const Compiled&>(other);
return std::equal(
tape_.begin(),
tape_.end(),
a_other.tape_.begin(),
a_other.tape_.end(),
[](const array& a1, const array& a2) {
auto& p1 = a1.primitive();
auto& p2 = a2.primitive();
return typeid(p1) == typeid(p2) && p1.is_equivalent(p2);
});
}
void Compiled::print(std::ostream& os) {
os << "Compiled";
for (auto& a : tape_) {
a.primitive().print(os);
}
}
namespace detail {
CompileMode& compile_mode() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
return CompileMode::disabled;
} else {
return CompileMode::enabled;
}
};
static CompileMode compile_mode_ = get_val();
return compile_mode_;
}
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
using ParentsMap =
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
// Helper that merges two arrays in the graph by setting the parents of the
// source to point to the destination
void merge(array& dst, array& src, ParentsMap& parents_map) {
// Canonicalize the order of the primitives outputs
auto sources = src.outputs();
auto dests = dst.outputs();
// For each src parent, point it to the corresponding dst
for (int i = 0; i < sources.size(); ++i) {
auto src_parents = parents_map.find(sources[i].id());
if (src_parents == parents_map.end()) {
continue;
}
auto& pairs = parents_map[dests[i].id()];
for (auto& parent : src_parents->second) {
parent.first.inputs()[parent.second] = dests[i];
pairs.push_back(parent);
}
// Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents);
}
};
template <typename T, typename... U>
size_t getAddress(std::function<T(U...)> f) {
typedef T(fnType)(U...);
@ -59,9 +221,10 @@ struct CompilerCache {
auto is_match = [](const std::vector<array>& in1,
const std::vector<array>& in2) {
if (in1.size() != in2.size()) {
throw std::runtime_error(
"[compiler] Got different number of inputs to function,"
" this should never happen.");
std::ostringstream msg;
msg << "[compiler] Unexpected number of inputs to compiled function:"
<< " expected " << in2.size() << " got " << in1.size() << ".";
throw std::invalid_argument(msg.str());
}
for (int i = 0; i < in1.size(); ++i) {
if (in1[i].shape() != in2[i].shape()) {
@ -205,28 +368,6 @@ void compile_simplify(
}
}
// Helper that fuses two arrays in the graph by setting the parents of the
// source to point to the destination
auto fuse = [&](array& dst, array& src) {
// Canonicalize the order of the primitives outputs
auto sources = src.outputs();
auto dests = dst.outputs();
// For each src parent, point it to the corresponding dest
for (int i = 0; i < sources.size(); ++i) {
auto src_parents = parents_map.find(sources[i].id());
if (src_parents == parents_map.end()) {
continue;
}
auto& pairs = parents_map[dests[i].id()];
for (auto& parent : src_parents->second) {
parent.first.inputs()[parent.second] = dests[i];
pairs.push_back(parent);
}
// Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents);
}
};
// Depth-1 array equivalence check.
auto array_equivalent = [](const array& a, const array& b) {
if (!a.has_primitive() || !b.has_primitive()) {
@ -254,33 +395,32 @@ void compile_simplify(
return pa.is_equivalent(pb);
};
// Pass 0: fuse scalars
// Merge scalars
std::vector<array> new_tape;
for (auto& arr : tape) {
// Check if we can fuse scalars
// Check if we can merge scalars
if (is_scalar(arr)) {
auto scalar = scalars.find(get_scalar_rep(arr));
if (scalar->second.id() != arr.id()) {
fuse(scalar->second, arr);
merge(scalar->second, arr, parents_map);
// Don't keep orphaned scalars in the tape
continue;
}
}
new_tape.push_back(std::move(arr));
}
tape = std::move(new_tape);
std::unordered_set<uintptr_t> output_set;
for (auto& o : outputs) {
output_set.insert(o.id());
}
// Pass 1..passes: fuse only keeping non-orphaned arrays in the tape
// Multi-pass merge only keeping non-orphaned arrays in the tape
for (int pass = 0; pass < passes; ++pass) {
for (auto& arr : tape) {
// Helper to check if we can fuse the parents of the
// Helper to check if we can merge the parents of the
// given array
auto maybe_fuse_parents = [&](auto& a) {
auto maybe_merge_parents = [&](auto& a) {
auto parents = parents_map.find(a.id());
if (parents != parents_map.end()) {
auto N = parents->second.size();
@ -296,7 +436,7 @@ void compile_simplify(
auto& src = parents->second[j].first;
auto& dst = parents->second[i].first;
if (src.id() != dst.id() && array_equivalent(src, dst)) {
fuse(dst, src);
merge(dst, src, parents_map);
mask[j] = true;
}
}
@ -313,9 +453,9 @@ void compile_simplify(
}
};
bool discard = maybe_fuse_parents(arr);
bool discard = maybe_merge_parents(arr);
for (auto& s : arr.siblings()) {
discard &= maybe_fuse_parents(s);
discard &= maybe_merge_parents(s);
}
// If an array and its siblings have no parents, and none of them are
// outputs, it is safe to remove it from the tape
@ -327,6 +467,216 @@ void compile_simplify(
}
}
// Extract sub-graphs of the graph that can be compiled
// and replace them with a Compiled Primitive.
void compile_fuse(
std::vector<array>& tape,
ParentsMap& parents_map,
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Track outputs to replace with new compiled outputs
std::unordered_map<uintptr_t, array> output_map;
for (auto& o : outputs) {
output_map.insert({o.id(), o});
}
// Set of inputs to distinguish constants
std::unordered_set<uintptr_t> input_ids;
for (auto& in : inputs) {
input_ids.insert(in.id());
}
// Go through the tape in reverse order and check for fusable sub-graphs
std::vector<array> new_tape;
std::unordered_set<uintptr_t> global_cache;
for (int i = tape.size() - 1; i >= 0; --i) {
auto& arr = tape[i];
// Already compiled
if (global_cache.find(arr.id()) != global_cache.end()) {
continue;
}
// Two pass recursion:
// First pass:
// - Collect all the primitives which we can fuse with
// - Keeps a cache of fusable primitives which may be added out of
// DAG order. We have to determine if all of a fused primitive's
// outputs are also in the fused section, and this may not be the
// case the first time we visit it.
// Second pass:
// - Collect inputs to the new compiled primitive
// - Add fusable primitives to a tape in the correct order
std::function<void(const array&, int, const Stream&)> recurse;
std::unordered_set<uintptr_t> cache;
recurse = [&](const array& a, int depth, const Stream& s) {
if (cache.find(a.id()) != cache.end()) {
return;
}
// Stop fusing if:
// - Depth limit exceeded
// - Constant input
// - Stream mismatch
// - Non fusable primitive
if (depth >= max_compile_depth || !a.has_primitive() ||
a.primitive().stream() != s || !is_fusable(a.primitive())) {
return;
}
bool all_parents_in = true;
if (depth > 0) {
// Guaranteed to have a parent since nested in the
// recursion.
auto& parents = parents_map.at(a.id());
for (auto& [p, idx] : parents) {
auto in_cache = cache.find(p.id()) != cache.end();
if (!in_cache) {
all_parents_in = false;
break;
}
}
}
// Arrays with a mix of parents outside the compilable section
// are not fusable
if (!all_parents_in) {
return;
}
cache.insert({a.id()});
for (auto& in : a.inputs()) {
recurse(in, depth + 1, s);
}
};
if (arr.has_primitive()) {
Stream s = arr.primitive().stream();
recurse(arr, 0, s);
}
// Not worth fusing a single primitive
if (cache.size() <= 1) {
new_tape.push_back(arr);
continue;
}
// Recurse a second time to build the tape in the right
// order and collect the inputs
std::unordered_set<uintptr_t> input_set;
std::vector<array> inputs;
std::vector<array> fused_tape;
std::unordered_set<uintptr_t> tape_set;
std::function<void(const array&)> recurse_tape;
recurse_tape = [&](const array& a) {
if (cache.find(a.id()) == cache.end()) {
if (input_set.find(a.id()) == input_set.end()) {
input_set.insert(a.id());
inputs.push_back(a);
}
return;
}
if (tape_set.find(a.id()) != tape_set.end()) {
return;
}
tape_set.insert(a.id());
for (auto& in : a.inputs()) {
recurse_tape(in);
}
fused_tape.push_back(a);
};
recurse_tape(arr);
std::vector<array> old_outputs;
// Add to global cache and add any global outputs to outputs
// of new primitive
for (int j = 0; j < fused_tape.size() - 1; ++j) {
auto& f = fused_tape[j];
if (output_map.find(f.id()) != output_map.end()) {
old_outputs.push_back(f);
// Parents are now siblings, update the parent map
auto& pairs = parents_map[f.id()];
pairs.erase(
std::remove_if(
pairs.begin(),
pairs.end(),
[&](auto& p) {
return cache.find(p.first.id()) != cache.end();
}),
pairs.end());
} else {
// Remove inner fused arrays parents from the parents map
// to keep the parents map in a valid state
parents_map.erase(f.id());
}
global_cache.insert({f.id()});
}
old_outputs.push_back(arr);
std::vector<std::vector<int>> shapes;
std::vector<Dtype> types;
for (auto& o : old_outputs) {
shapes.push_back(o.shape());
types.push_back(o.dtype());
}
std::unordered_set<uintptr_t> constant_ids;
for (auto& in : inputs) {
// Scalar constant
if (in.size() == 1 && !in.has_primitive() &&
input_ids.find(in.id()) == input_ids.end()) {
constant_ids.insert(in.id());
}
}
auto compiled_outputs = array::make_arrays(
shapes,
types,
std::make_shared<Compiled>(
outputs.back().primitive().stream(),
inputs,
old_outputs,
std::move(fused_tape),
std::move(constant_ids)),
inputs);
// One output per primitive
new_tape.push_back(compiled_outputs.back());
// Replace inputs old parents with compiled_outputs
for (int i = 0; i < inputs.size(); ++i) {
auto& pairs = parents_map[inputs[i].id()];
pairs.erase(
std::remove_if(
pairs.begin(),
pairs.end(),
[&](auto& p) { return cache.find(p.first.id()) != cache.end(); }),
pairs.end());
for (auto& o : compiled_outputs) {
pairs.push_back({o, i});
}
}
// - Update outputs parents to point to compiled outputs
// - Update any overall graph outputs to be compiled outputs
for (int o = 0; o < old_outputs.size(); ++o) {
merge(compiled_outputs[o], old_outputs[o], parents_map);
if (auto it = output_map.find(old_outputs[o].id());
it != output_map.end()) {
it->second = compiled_outputs[o];
}
}
}
std::reverse(new_tape.begin(), new_tape.end());
tape = std::move(new_tape);
// Replace output with potentially compiled output
for (auto& o : outputs) {
o = output_map.at(o.id());
}
}
std::vector<array> compile_replace(
const std::vector<array>& tape,
const std::vector<array>& trace_inputs,
@ -380,7 +730,7 @@ std::vector<array> compile_replace(
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
size_t fun_id) {
if (compiler_disabled()) {
if (compile_mode() == CompileMode::disabled) {
return fun;
}
return [fun, fun_id](const std::vector<array>& inputs) {
@ -402,10 +752,16 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
compile_dfs(entry.inputs, entry.outputs);
// Simplify the tape
compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 3);
if (compile_mode() != CompileMode::no_simplify) {
compile_simplify(
entry.tape, parents_map, entry.outputs, /* passes */ 3);
}
// This is a good point to do more optimizations, e.g. kernel fusion to
// generate new primitives. The tape needs to be updated accordingly
// Kernel fusion to generate Compiled primitives. The tape and
// new outputs must be updated accordingly
if (compile_mode() != CompileMode::no_fuse) {
compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
}
}
// At this point we must have a tape, now replace the placeholders
@ -422,7 +778,7 @@ void compile_erase(size_t fun_id) {
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
if (detail::compiler_disabled()) {
if (detail::compile_mode() == CompileMode::disabled) {
return fun;
}
auto fun_id = detail::getAddress(fun);
@ -430,11 +786,15 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
}
void disable_compile() {
detail::compiler_disabled() = true;
detail::compile_mode() = CompileMode::disabled;
}
void enable_compile() {
detail::compiler_disabled() = false;
detail::compile_mode() = CompileMode::enabled;
}
void set_compile_mode(CompileMode mode) {
detail::compile_mode() = mode;
}
} // namespace mlx::core

28
mlx/compile.h Normal file
View File

@ -0,0 +1,28 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
enum class CompileMode { disabled, no_simplify, no_fuse, enabled };
// Compile takes a function and returns a new function
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
/** Globally disable compilation.
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
* be used to disable compilation.
*/
void disable_compile();
/** Globally enable compilation.
* This will override the environment variable ``MLX_DISABLE_COMPILE``.
*/
void enable_compile();
/** Set the compiler mode to the given value. */
void set_compile_mode(CompileMode mode);
} // namespace mlx::core

View File

@ -4,6 +4,7 @@
#include "mlx/array.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/compile.h"
#include "mlx/device.h"
#include "mlx/fft.h"
#include "mlx/io.h"

View File

@ -2,6 +2,8 @@
#pragma once
#include <unordered_set>
#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/io/load.h"
@ -451,6 +453,46 @@ class Ceil : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Compiled : public Primitive {
public:
/*
* The inputs, outputs and tape are either tracers or constants.
* - The tape should not contain the inputs, but it should contain the
* outputs.
* - The tape should also have only one array per primitive for multi-output
* primitives.
* - The constant_ids contains ids of arrays in the input list that are safe
* to treat as scalar constants.
*/
explicit Compiled(
Stream stream,
std::vector<array> inputs,
std::vector<array> outputs,
std::vector<array> tape,
std::unordered_set<uintptr_t> constant_ids);
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_VMAP()
DEFINE_GRADS()
void print(std::ostream& os) override;
bool is_equivalent(const Primitive& other) const override;
private:
const std::vector<array> inputs_;
const std::vector<array> outputs_;
const std::vector<array> tape_;
const std::unordered_set<uintptr_t> constant_ids_;
void eval(const std::vector<array>& inputs, std::vector<array>& out);
};
class Concatenate : public UnaryPrimitive {
public:
explicit Concatenate(Stream stream, int axis)

View File

@ -6,21 +6,6 @@
namespace mlx::core {
// Compile takes a function and returns a new function
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
/** Globally disable compilation.
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
* be used to disable compilation.
*/
void disable_compile();
/** Globally enable compilation.
* This will override the environment variable ``MLX_DISABLE_COMPILE``.
*/
void enable_compile();
void eval(const std::vector<array>& outputs);
template <typename... Arrays>

View File

@ -7,6 +7,7 @@
#include <sstream>
#include "mlx/array.h"
#include "mlx/compile.h"
#include "mlx/graph_utils.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"

View File

@ -190,6 +190,117 @@ class TestCompile(mlx_tests.MLXTestCase):
n_enable_compiled = count_prims(cfun(x))
self.assertEqual(n_compiled, n_enable_compiled)
def test_compile_two_input_grad(self):
def loss(w, x):
y = x * w
return (y * mx.exp(y)).sum()
x = mx.array([1.0, 0.5, 2.0, -0.5])
w = mx.array([-1.0, 0.3, 1.0, -0.9])
expected_grad = mx.grad(loss)(w, x)
compiled_grad = mx.compile(mx.grad(loss))(w, x)
self.assertTrue(mx.allclose(expected_grad, compiled_grad))
def test_vmap_compiled(self):
def simple_unary(x):
return -mx.exp(x)
x = mx.array([[1.0, 2.0], [2.0, 3.0]])
expected_out = mx.vmap(simple_unary)(x)
out = mx.vmap(mx.compile(simple_unary))(x)
self.assertTrue(mx.allclose(expected_out, out))
def simple_binary(x, y):
return mx.abs(mx.exp(x + y) + y)
x = mx.array([[1.0, -3.0], [0.5, -0.5]])
y = mx.array([[2.0, -1.0], [0.25, -0.25]])
expected_out = mx.vmap(simple_binary)(x, y)
out = mx.vmap(mx.compile(simple_binary))(x, y)
self.assertTrue(mx.allclose(expected_out, out))
expected_out = mx.vmap(simple_binary, in_axes=(0, 1))(x, y)
out = mx.vmap(mx.compile(simple_binary), in_axes=(0, 1))(x, y)
self.assertTrue(mx.allclose(expected_out, out))
y = mx.array([0.25, -0.25])
expected_out = mx.vmap(simple_binary, in_axes=(0, None))(x, y)
out = mx.vmap(mx.compile(simple_binary), in_axes=(0, None))(x, y)
self.assertTrue(mx.allclose(expected_out, out))
def simple_unary_outer(x):
x = mx.abs(x)
@mx.compile
def simple_unary_inner(z):
return -mx.exp(x)
return simple_unary_inner(x)
expected_out = -mx.exp(mx.abs(x))
out = mx.vmap(simple_unary_outer)(x)
self.assertTrue(mx.allclose(expected_out, out))
def test_vjp_vjp_compiled(self):
def simple_unary(x):
return -mx.exp(x)
x = mx.array([[1.0, 2.0], [2.0, 3.0]])
y = mx.array([[1.0, 1.0], [1.0, 1.0]])
expected_out, expected_vjp_out = mx.vjp(simple_unary, (x,), (y,))
out, vjp_out = mx.vjp(mx.compile(simple_unary), (x,), (y,))
self.assertTrue(mx.allclose(expected_vjp_out[0], vjp_out[0]))
self.assertTrue(mx.allclose(expected_out[0], out[0]))
expected_out, expected_jvp_out = mx.jvp(simple_unary, (x,), (y,))
out, jvp_out = mx.jvp(mx.compile(simple_unary), (x,), (y,))
self.assertTrue(mx.allclose(expected_jvp_out[0], jvp_out[0]))
self.assertTrue(mx.allclose(expected_out[0], out[0]))
def simple_binary(x, y):
return mx.abs(mx.exp(x + y) + y)
x = mx.array([[1.0, -3.0], [0.5, -0.5]])
y = mx.array([[2.0, -1.0], [0.25, -0.25]])
cotans = mx.ones_like(x)
expected_out, expected_vjp_out = mx.vjp(simple_binary, (x, y), (cotans,))
out, vjp_out = mx.vjp(mx.compile(simple_binary), (x, y), (cotans,))
self.assertTrue(mx.allclose(expected_out[0], out[0]))
self.assertTrue(mx.allclose(expected_vjp_out[0], vjp_out[0]))
self.assertTrue(mx.allclose(expected_vjp_out[1], vjp_out[1]))
tans = (mx.ones_like(x), mx.ones_like(y))
expected_out, expected_jvp_out = mx.jvp(simple_binary, (x, y), tans)
out, jvp_out = mx.jvp(mx.compile(simple_binary), (x, y), tans)
self.assertTrue(mx.allclose(expected_jvp_out[0], jvp_out[0]))
self.assertTrue(mx.allclose(expected_out[0], out[0]))
def test_transform_over_eval_compiled(self):
def outer(x):
y = mx.exp(mx.abs(x))
mx.eval(y)
return y.sum()
x = mx.array([2.0, -1.0, 0.5])
dfdx = mx.grad(outer)(x)
@mx.compile
def simple_unary(x):
return mx.exp(mx.abs(x))
def outer(x):
y = simple_unary(x)
mx.eval(y)
return y.sum()
cdfdx = mx.grad(outer)(x)
self.assertTrue(mx.allclose(dfdx, cdfdx))
if __name__ == "__main__":
unittest.main()

View File

@ -3,6 +3,7 @@
#include "doctest/doctest.h"
#include "mlx/mlx.h"
#include "mlx/primitives.h"
using namespace mlx::core;
@ -120,6 +121,7 @@ auto max_scalars(const std::vector<array>&) {
};
TEST_CASE("test simplify scalars") {
set_compile_mode(CompileMode::no_fuse);
{
auto cfun = compile(add_scalars);
auto out = cfun({});
@ -136,6 +138,7 @@ TEST_CASE("test simplify scalars") {
auto d = out[2];
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
}
set_compile_mode(CompileMode::enabled);
}
auto exp_two(const std::vector<array>& inputs) {
@ -144,9 +147,11 @@ auto exp_two(const std::vector<array>& inputs) {
};
TEST_CASE("test simplify") {
set_compile_mode(CompileMode::no_fuse);
auto a = array({1.0f, 2.0f});
auto b = compile(exp_two)({a})[0];
CHECK(b.inputs()[0].id() == b.inputs()[1].id());
set_compile_mode(CompileMode::enabled);
}
auto add_diff(const std::vector<array>& inputs) {
@ -155,9 +160,11 @@ auto add_diff(const std::vector<array>& inputs) {
};
TEST_CASE("test no simplify") {
set_compile_mode(CompileMode::no_fuse);
auto a = array({1.0f, 2.0f});
auto b = compile(add_diff)({a})[0];
CHECK(b.inputs()[0].id() != b.inputs()[1].id());
set_compile_mode(CompileMode::enabled);
}
auto multi_one(const std::vector<array>&) {
@ -187,6 +194,7 @@ auto multi_three(const std::vector<array>&) {
}
TEST_CASE("test simplify multi output") {
set_compile_mode(CompileMode::no_fuse);
{
auto out = compile(multi_one)({});
auto e = out[0];
@ -210,4 +218,372 @@ TEST_CASE("test simplify multi output") {
CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
}
set_compile_mode(CompileMode::enabled);
}
// No fusion
auto unary_fused_0(const std::vector<array>& inputs) {
return std::vector<array>{exp(inputs[0])};
}
// All compilable
auto unary_fused_1(const std::vector<array>& inputs) {
return std::vector<array>{abs(negative(exp(inputs[0])))};
}
auto unary_fused_1_copy(const std::vector<array>& inputs) {
return std::vector<array>{abs(negative(exp(inputs[0])))};
}
auto unary_fused_1_diff(const std::vector<array>& inputs) {
return std::vector<array>{abs(exp(negative(inputs[0])))};
}
// Output into un-compilable primitive
auto unary_fused_2(const std::vector<array>& inputs) {
return std::vector<array>{sum(abs(negative(exp(inputs[0]))), true)};
}
// Input from un-compilable primitive
auto unary_fused_3(const std::vector<array>& inputs) {
return std::vector<array>{exp(abs(negative(sum(inputs[0], true))))};
}
TEST_CASE("test compile unary fused") {
// NB: some of these tests are brittle and may need to be
// updated if we change compile conditions
{
auto cfun = compile(unary_fused_0);
auto x = array(2.0);
auto out = cfun({x})[0];
auto& p = out.primitive();
CHECK_EQ(typeid(p), typeid(Exp));
CHECK_EQ(out.inputs()[0].id(), x.id());
}
{
auto cfun = compile(unary_fused_1);
auto x = array(2.0);
auto out = cfun({x})[0];
auto& p = out.primitive();
CHECK_EQ(typeid(p), typeid(Compiled));
CHECK_EQ(out.inputs()[0].id(), x.id());
auto expected_out = unary_fused_1({array(2.0)})[0];
CHECK_EQ(out.item<float>(), expected_out.item<float>());
}
{
auto cfun = compile(unary_fused_2);
auto x = array({1.0, 2.0});
auto out = cfun({x});
CHECK_EQ(out.size(), 1);
auto& p = out[0].primitive();
// NB: this test is brittle, will need to update
// it if we change compile conditions
CHECK_EQ(typeid(p), typeid(Reduce));
auto cout = out[0].inputs()[0];
auto& cp = cout.primitive();
CHECK_EQ(typeid(cp), typeid(Compiled));
CHECK_EQ(cout.inputs()[0].id(), x.id());
}
{
auto cfun = compile(unary_fused_3);
auto x = array({1.0, 2.0});
auto out = cfun({x});
auto& p = out[0].primitive();
CHECK_EQ(typeid(p), typeid(Compiled));
auto sout = out[0].inputs()[0];
CHECK_EQ(out[0].inputs().size(), 1);
auto& sp = sout.primitive();
CHECK_EQ(typeid(sp), typeid(Reduce));
CHECK_EQ(sout.inputs()[0].id(), x.id());
}
// Is equivalent works
{
auto out1 = compile(unary_fused_1)({array(1.0)});
auto out2 = compile(unary_fused_1_copy)({array(1.0)});
CHECK(out1[0].primitive().is_equivalent(out2[0].primitive()));
auto out3 = compile(unary_fused_1_diff)({array(1.0)});
CHECK(!out1[0].primitive().is_equivalent(out3[0].primitive()));
}
}
// All compilable
auto binary_fused_0(const std::vector<array>& inputs) {
return std::vector<array>{inputs[0] + inputs[1]};
}
// Binary into unary
auto binary_fused_1(const std::vector<array>& inputs) {
return std::vector<array>{abs(inputs[0] + inputs[1])};
}
// Binary into binary
auto binary_fused_2(const std::vector<array>& inputs) {
auto x = inputs[0] + inputs[1];
return std::vector<array>{x + inputs[0]};
}
// Binary into unary into un-compilable
auto binary_fused_3(const std::vector<array>& inputs) {
return std::vector<array>{sum(abs(inputs[0] + inputs[1]), true)};
}
TEST_CASE("test compile binary fused") {
{
auto cfun = compile(binary_fused_0);
auto x = array(2.0);
auto y = array(2.0);
auto out = cfun({x, y})[0];
auto& p = out.primitive();
CHECK_EQ(typeid(p), typeid(Add));
CHECK_EQ(out.inputs()[0].id(), x.id());
}
{
auto cfun = compile(binary_fused_1);
auto x = array(2.0);
auto y = array(2.0);
auto out = cfun({x, y})[0];
auto& p = out.primitive();
CHECK_EQ(typeid(p), typeid(Compiled));
CHECK_EQ(out.inputs()[0].id(), x.id());
CHECK_EQ(out.inputs()[1].id(), y.id());
auto expected_out = binary_fused_1({x, y})[0];
CHECK_EQ(out.item<float>(), expected_out.item<float>());
}
{
auto cfun = compile(binary_fused_2);
auto x = array(2.0);
auto y = array(2.0);
auto out = cfun({x, y})[0];
auto& p = out.primitive();
CHECK_EQ(typeid(p), typeid(Compiled));
CHECK_EQ(out.inputs()[0].id(), x.id());
CHECK_EQ(out.inputs()[1].id(), y.id());
}
{
auto cfun = compile(binary_fused_3);
auto x = array({1.0, 2.0});
auto y = array({1.0, 2.0});
auto out = cfun({x, y})[0];
auto& p = out.primitive();
CHECK_EQ(typeid(p), typeid(Reduce));
auto cout = out.inputs()[0];
auto& cp = cout.primitive();
CHECK_EQ(typeid(cp), typeid(Compiled));
CHECK_EQ(cout.inputs()[0].id(), x.id());
CHECK_EQ(cout.inputs()[1].id(), y.id());
}
}
auto gelu_1(const std::vector<array>& inputs) {
auto& x = inputs[0];
auto out = x * (1.0f + erf(x / M_SQRT2)) / 2.0f;
return std::vector<array>{out};
}
TEST_CASE("test compile gelu") {
{
auto cfun = compile(gelu_1);
auto x = array(1.0);
auto out = cfun({x})[0];
auto& p = out.primitive();
CHECK_EQ(typeid(p), typeid(Compiled));
CHECK_EQ(out.inputs().size(), 4);
for (auto& in : out.inputs()) {
CHECK(in.inputs().empty());
}
auto expected_out = gelu_1({x})[0];
CHECK(allclose(out, expected_out).item<bool>());
}
{
auto cfun = compile(gelu_1);
auto x = array({1.0, 0.5});
auto out = cfun({x})[0];
auto& p = out.primitive();
CHECK_EQ(typeid(p), typeid(Compiled));
CHECK_EQ(out.inputs().size(), 4);
for (auto& in : out.inputs()) {
CHECK(in.inputs().empty());
}
auto expected_out = gelu_1({x})[0];
CHECK(allclose(out, expected_out).item<bool>());
}
}
// Uncompilable input outside fused tape
auto unary_with_two_outputs(const std::vector<array>& inputs) {
auto x = exp(inputs[0]);
return std::vector<array>{exp(x), sum(x, true)};
}
auto uncompilable_inputs(const std::vector<array>& inputs) {
auto& x = inputs[0];
auto& y = inputs[1];
return std::vector<array>{x * abs(exp(y)), sum(x, true)};
}
auto uncompilable_inputs_order_matters(const std::vector<array>& inputs) {
auto& x = inputs[0];
auto& y = inputs[1];
return std::vector<array>{x / abs(exp(y)), sum(x, true)};
}
TEST_CASE("test compile tape with outside parents") {
{
auto cfun = compile(unary_with_two_outputs);
auto x = array({2.0, 2.0});
auto out = cfun({x});
auto& p1 = out[0].primitive();
CHECK_EQ(typeid(p1), typeid(Exp));
auto& p2 = out[1].primitive();
CHECK_EQ(typeid(p2), typeid(Reduce));
}
{
auto cfun = compile(uncompilable_inputs);
auto x = array({2.0, 2.0});
auto y = array({1.6, 0.6});
auto outs = cfun({x, y});
auto& p1 = outs[0].primitive();
CHECK_EQ(typeid(p1), typeid(Compiled));
auto& p2 = outs[1].primitive();
CHECK_EQ(typeid(p2), typeid(Reduce));
CHECK_EQ(outs[0].inputs().size(), 2);
auto expected_outs = uncompilable_inputs({x, y});
CHECK(allclose(outs[0], expected_outs[0]).item<bool>());
CHECK(allclose(outs[1], expected_outs[1]).item<bool>());
}
{
auto cfun = compile(uncompilable_inputs_order_matters);
auto x = array({2.0, 2.0});
auto y = array({1.6, 0.6});
auto outs = cfun({x, y});
auto& p1 = outs[0].primitive();
CHECK_EQ(typeid(p1), typeid(Compiled));
auto& p2 = outs[1].primitive();
CHECK_EQ(typeid(p2), typeid(Reduce));
CHECK_EQ(outs[0].inputs().size(), 2);
auto expected_outs = uncompilable_inputs_order_matters({x, y});
CHECK(allclose(outs[0], expected_outs[0]).item<bool>());
CHECK(allclose(outs[1], expected_outs[1]).item<bool>());
}
}
auto compile_accross_streams(const std::vector<array>& inputs) {
auto s2 = new_stream(default_device());
auto x = exp(abs(inputs[0]));
auto y = exp(abs(x, s2), s2);
return std::vector<array>{y};
}
TEST_CASE("test compile accross streams") {
auto cfun = compile(compile_accross_streams);
auto x = array({2.0f});
auto out = cfun({x})[0];
auto& p1 = out.primitive();
CHECK_EQ(typeid(p1), typeid(Compiled));
CHECK_EQ(out.inputs().size(), 1);
auto child = out.inputs()[0];
auto& p2 = child.primitive();
CHECK_EQ(typeid(p2), typeid(Compiled));
CHECK_EQ(child.inputs()[0].id(), x.id());
}
auto unary_compile_outputs(const std::vector<array>& inputs) {
auto x = abs(inputs[0]);
auto y = square(x);
return std::vector<array>{x, y};
}
auto binary_compile_outputs(const std::vector<array>& inputs) {
auto x = inputs[0];
auto y = inputs[1];
x = x + y;
y = x + y;
return std::vector<array>{x, y};
}
TEST_CASE("test compile internal output") {
{
auto cfun = compile(unary_compile_outputs);
auto x = array({3, -2});
auto outs = cfun({x});
auto& p1 = outs[0].primitive();
CHECK_EQ(typeid(p1), typeid(Compiled));
auto& p2 = outs[1].primitive();
CHECK_EQ(typeid(p2), typeid(Compiled));
CHECK_EQ(outs[0].siblings()[0].id(), outs[1].id());
auto expected_outs = unary_compile_outputs({x});
CHECK(array_equal(outs[0], expected_outs[0]).item<bool>());
CHECK(array_equal(outs[1], expected_outs[1]).item<bool>());
}
{
auto cfun = compile(binary_compile_outputs);
auto x = array({3, -2});
auto y = array({1, -1});
auto outs = cfun({x, y});
auto& p1 = outs[0].primitive();
CHECK_EQ(typeid(p1), typeid(Compiled));
auto& p2 = outs[1].primitive();
CHECK_EQ(typeid(p2), typeid(Compiled));
auto expected_outs = binary_compile_outputs({x, y});
CHECK(array_equal(outs[0], expected_outs[0]).item<bool>());
CHECK(array_equal(outs[1], expected_outs[1]).item<bool>());
}
}
auto deep_unary_compile(const std::vector<array>& inputs) {
auto x = inputs[0];
for (int i = 0; i < 10; ++i) {
x = cos(sin(x));
}
return std::vector<array>{x};
}
TEST_CASE("test compile deep graph") {
auto cfun = compile(deep_unary_compile);
auto x = array({3.0f, -2.0f});
auto out = cfun({x})[0];
auto expected_out = deep_unary_compile({x})[0];
CHECK(allclose(out, expected_out).item<bool>());
}
auto repeat_input_to_compiled(const std::vector<array>& inputs) {
auto x = abs(exp(inputs[0]));
auto y = abs(exp(sum(x)));
return std::vector<array>{x + y};
}
TEST_CASE("test compile repeat input") {
auto cfun = compile(repeat_input_to_compiled);
auto x = array({3.0f, -2.0f});
auto out = cfun({x})[0];
auto expected_out = repeat_input_to_compiled({x})[0];
CHECK(allclose(out, expected_out).item<bool>());
}