diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index b24a2bc74..6cd851111 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -33,10 +33,12 @@ DEFAULT(ArgSort) DEFAULT(AsStrided) DEFAULT(Broadcast) DEFAULT(Ceil) +DEFAULT_MULTI(Compiled) DEFAULT(Concatenate) DEFAULT(Copy) DEFAULT_MULTI(CustomVJP) DEFAULT_MULTI(Depends) +DEFAULT_MULTI(DivMod) DEFAULT(Equal) DEFAULT(Erf) DEFAULT(ErfInv) @@ -57,6 +59,7 @@ DEFAULT(Minimum) DEFAULT(NotEqual) DEFAULT(Pad) DEFAULT(Partition) +DEFAULT_MULTI(QRF) DEFAULT(RandomBits) DEFAULT(Reshape) DEFAULT(Round) @@ -68,8 +71,6 @@ DEFAULT_MULTI(Split) DEFAULT(Sort) DEFAULT(StopGradient) DEFAULT(Transpose) -DEFAULT_MULTI(DivMod) -DEFAULT_MULTI(QRF) void Abs::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 25563f66b..0263dff9b 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -3,6 +3,7 @@ target_sources( PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp new file mode 100644 index 000000000..149556530 --- /dev/null +++ b/mlx/backend/common/compiled.cpp @@ -0,0 +1,59 @@ +// Copyright © 2023-2024 Apple Inc. + +#include + +#include "mlx/primitives.h" + +namespace mlx::core { + +// Build the real tape +std::pair, std::vector> trace_to_real( + const std::vector& trace_tape, + const std::vector& trace_inputs, + const std::vector& trace_outputs, + const std::vector& inputs) { + std::unordered_map trace_to_real; + for (int i = 0; i < inputs.size(); ++i) { + trace_to_real.insert({trace_inputs[i].id(), inputs[i]}); + } + std::queue tape; + for (auto& a : trace_tape) { + // Find real inputs + std::vector real_inputs; + for (auto& in : a.inputs()) { + real_inputs.push_back(trace_to_real.at(in.id())); + } + tape.push( + array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs))); + trace_to_real.insert({a.id(), tape.back()}); + } + + std::vector outputs; + for (auto& o : trace_outputs) { + outputs.push_back(trace_to_real.at(o.id())); + } + return {tape, outputs}; +} + +void Compiled::eval( + const std::vector& inputs, + std::vector& outputs) { + // Make the a real tape from the tracers + auto [tape, real_outputs] = trace_to_real(tape_, inputs_, outputs_, inputs); + + // Run the tape + while (!tape.empty()) { + auto a = std::move(tape.front()); + tape.pop(); + auto outputs = a.outputs(); + a.primitive().eval_cpu(a.inputs(), outputs); + a.detach(); + } + + // Copy results into outputs + for (int o = 0; o < real_outputs.size(); ++o) { + outputs[o].copy_shared_buffer(real_outputs[o]); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index d9023fe9b..06c9db58d 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -41,7 +41,9 @@ DEFAULT(ArgSort) DEFAULT(AsType) DEFAULT(AsStrided) DEFAULT(Broadcast) +DEFAULT_MULTI(DivMod) DEFAULT(Ceil) +DEFAULT_MULTI(Compiled) DEFAULT(Concatenate) DEFAULT(Convolution) DEFAULT(Copy) @@ -78,6 +80,7 @@ DEFAULT(NotEqual) DEFAULT(Pad) DEFAULT(Partition) DEFAULT(Power) +DEFAULT_MULTI(QRF) DEFAULT(QuantizedMatmul) DEFAULT(RandomBits) DEFAULT(Reduce) @@ -100,8 +103,6 @@ DEFAULT(Subtract) DEFAULT(Tan) DEFAULT(Tanh) DEFAULT(Transpose) -DEFAULT_MULTI(DivMod) -DEFAULT_MULTI(QRF) namespace { diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index e12402c18..fd1a47f01 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -2,6 +2,7 @@ target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp new file mode 100644 index 000000000..4ce9b0c85 --- /dev/null +++ b/mlx/backend/metal/compiled.cpp @@ -0,0 +1,44 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/backend/metal/device.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +void Compiled::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + // Just a fall-back to the original tape for now + std::unordered_map trace_to_real; + for (int i = 0; i < inputs.size(); ++i) { + trace_to_real.insert({inputs_[i].id(), inputs[i]}); + } + for (int i = 0; i < outputs.size(); ++i) { + trace_to_real.insert({outputs_[i].id(), outputs[i]}); + } + + for (auto& a : tape_) { + std::vector p_inputs; + for (auto& in : a.inputs()) { + p_inputs.push_back(trace_to_real.at(in.id())); + } + // If a is an output get it from the map, otherwise create it + // NB this is safe as long as no multi-output sub primitves are allowed + // in Compiled + std::vector p_outputs; + if (auto it = trace_to_real.find(a.id()); it != trace_to_real.end()) { + p_outputs.push_back(it->second); + } else { + p_outputs.push_back(array(a.shape(), a.dtype(), a.primitive_ptr(), {})); + trace_to_real.insert({a.id(), p_outputs[0]}); + } + a.primitive().eval_gpu(p_inputs, p_outputs); + } + auto& s = stream(); + auto& d = metal::device(s.device); + auto command_buffer = d.get_command_buffer(s.index); + command_buffer->addCompletedHandler( + [trace_to_real](MTL::CommandBuffer*) mutable {}); +} + +} // namespace mlx::core diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 477ceff20..dd4edc2ed 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -32,6 +32,7 @@ NO_GPU(AsType) NO_GPU(AsStrided) NO_GPU(Broadcast) NO_GPU(Ceil) +NO_GPU_MULTI(Compiled) NO_GPU(Concatenate) NO_GPU(Convolution) NO_GPU(Copy) @@ -40,6 +41,7 @@ NO_GPU(Cosh) NO_GPU_MULTI(CustomVJP) NO_GPU_MULTI(Depends) NO_GPU(Divide) +NO_GPU_MULTI(DivMod) NO_GPU(Remainder) NO_GPU(Equal) NO_GPU(Erf) @@ -69,6 +71,7 @@ NO_GPU(NotEqual) NO_GPU(Pad) NO_GPU(Partition) NO_GPU(Power) +NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) NO_GPU(RandomBits) NO_GPU(Reduce) @@ -91,6 +94,5 @@ NO_GPU(Subtract) NO_GPU(Tan) NO_GPU(Tanh) NO_GPU(Transpose) -NO_GPU_MULTI(DivMod) -NO_GPU_MULTI(QRF) + } // namespace mlx::core diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 19a0d368f..743af64ed 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -1,36 +1,198 @@ // Copyright © 2023-2024 Apple Inc. + #include #include #include #include #include "mlx/allocator.h" +#include "mlx/compile.h" #include "mlx/primitives.h" #include "mlx/transforms.h" #include "mlx/transforms_impl.h" namespace mlx::core { -namespace detail { +constexpr int max_compile_depth = 6; -bool& compiler_disabled() { - auto get_val = []() { - if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) { - return true; - } else { - return false; - } - }; - static bool compiler_disabled_ = get_val(); - return compiler_disabled_; +bool is_unary(const Primitive& p) { + return ( + typeid(p) == typeid(Abs) || typeid(p) == typeid(ArcCos) || + typeid(p) == typeid(ArcCosh) || typeid(p) == typeid(ArcSin) || + typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) || + typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) || + typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) || + typeid(p) == typeid(Cosh) || typeid(p) == typeid(Remainder) || + typeid(p) == typeid(Erf) || typeid(p) == typeid(ErfInv) || + typeid(p) == typeid(Exp) || typeid(p) == typeid(Floor) || + typeid(p) == typeid(Log) || typeid(p) == typeid(Log1p) || + typeid(p) == typeid(LogicalNot) || typeid(p) == typeid(Negative) || + typeid(p) == typeid(Round) || typeid(p) == typeid(Sigmoid) || + typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) || + typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) || + typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) || + typeid(p) == typeid(Tanh)); } -#define MAX_OPS_PER_BUFFER max_ops_per_buffer() +bool is_binary(const Primitive& p) { + return ( + typeid(p) == typeid(Add) || typeid(p) == typeid(Divide) || + typeid(p) == typeid(Equal) || typeid(p) == typeid(Greater) || + typeid(p) == typeid(GreaterEqual) || typeid(p) == typeid(Less) || + typeid(p) == typeid(LessEqual) || typeid(p) == typeid(LogicalNot) || + typeid(p) == typeid(LogicalAnd) || typeid(p) == typeid(LogicalOr) || + typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) || + typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) || + typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) || + typeid(p) == typeid(Subtract)); +} + +bool is_broadcast(const Primitive& p) { + return typeid(p) == typeid(Broadcast); +} + +bool is_noop(const Primitive& p) { + return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient); +} + +bool is_fusable(const Primitive& p) { + return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p); +} + +namespace detail { + +std::vector compile_replace( + const std::vector& tape, + const std::vector& trace_inputs, + const std::vector& trace_outputs, + const std::vector& inputs); + +} // namespace detail + +Compiled::Compiled( + Stream stream, + std::vector inputs, + std::vector outputs, + std::vector tape, + std::unordered_set constant_ids) + : Primitive(stream), + inputs_(std::move(inputs)), + outputs_(std::move(outputs)), + tape_(std::move(tape)), + constant_ids_(std::move(constant_ids)) {} + +std::vector Compiled::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) { + auto fun = [this](const std::vector& inputs) { + return detail::compile_replace(tape_, inputs_, outputs_, inputs); + }; + + auto [_, vjps] = mlx::core::vjp(fun, primals, cotangents); + std::vector vjp_outs; + for (int i = 0, j = 0; i < vjps.size(); ++i) { + if (i < argnums.size() && i == argnums[j]) { + vjp_outs.push_back(vjps[i]); + j++; + } + } + return vjp_outs; +} + +std::vector Compiled::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + auto fun = [this](const std::vector& inputs) { + return detail::compile_replace(tape_, inputs_, outputs_, inputs); + }; + + auto [_, jvps] = mlx::core::jvp(fun, primals, tangents); + std::vector jvp_outs; + for (int i = 0, j = 0; i < jvps.size(); ++i) { + if (i < argnums.size() && i == argnums[j]) { + jvp_outs.push_back(jvps[i]); + j++; + } + } + return jvp_outs; +} + +std::pair, std::vector> Compiled::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto fun = [this](const std::vector& inputs) { + return detail::compile_replace(tape_, inputs_, outputs_, inputs); + }; + + auto outputs = mlx::core::vmap(fun, axes)(inputs); + auto out_axes = std::vector(outputs.size(), 0); + return {outputs, out_axes}; +} + +bool Compiled::is_equivalent(const Primitive& other) const { + const Compiled& a_other = static_cast(other); + return std::equal( + tape_.begin(), + tape_.end(), + a_other.tape_.begin(), + a_other.tape_.end(), + [](const array& a1, const array& a2) { + auto& p1 = a1.primitive(); + auto& p2 = a2.primitive(); + return typeid(p1) == typeid(p2) && p1.is_equivalent(p2); + }); +} + +void Compiled::print(std::ostream& os) { + os << "Compiled"; + for (auto& a : tape_) { + a.primitive().print(os); + } +} + +namespace detail { + +CompileMode& compile_mode() { + auto get_val = []() { + if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) { + return CompileMode::disabled; + } else { + return CompileMode::enabled; + } + }; + static CompileMode compile_mode_ = get_val(); + return compile_mode_; +} using CompileFn = std::function(const std::vector&)>; using ParentsMap = std::unordered_map>>; +// Helper that merges two arrays in the graph by setting the parents of the +// source to point to the destination +void merge(array& dst, array& src, ParentsMap& parents_map) { + // Canonicalize the order of the primitives outputs + auto sources = src.outputs(); + auto dests = dst.outputs(); + // For each src parent, point it to the corresponding dst + for (int i = 0; i < sources.size(); ++i) { + auto src_parents = parents_map.find(sources[i].id()); + if (src_parents == parents_map.end()) { + continue; + } + auto& pairs = parents_map[dests[i].id()]; + for (auto& parent : src_parents->second) { + parent.first.inputs()[parent.second] = dests[i]; + pairs.push_back(parent); + } + // Remove the source from the map to avoid fusing with it again + parents_map.erase(src_parents); + } +}; + template size_t getAddress(std::function f) { typedef T(fnType)(U...); @@ -59,9 +221,10 @@ struct CompilerCache { auto is_match = [](const std::vector& in1, const std::vector& in2) { if (in1.size() != in2.size()) { - throw std::runtime_error( - "[compiler] Got different number of inputs to function," - " this should never happen."); + std::ostringstream msg; + msg << "[compiler] Unexpected number of inputs to compiled function:" + << " expected " << in2.size() << " got " << in1.size() << "."; + throw std::invalid_argument(msg.str()); } for (int i = 0; i < in1.size(); ++i) { if (in1[i].shape() != in2[i].shape()) { @@ -205,28 +368,6 @@ void compile_simplify( } } - // Helper that fuses two arrays in the graph by setting the parents of the - // source to point to the destination - auto fuse = [&](array& dst, array& src) { - // Canonicalize the order of the primitives outputs - auto sources = src.outputs(); - auto dests = dst.outputs(); - // For each src parent, point it to the corresponding dest - for (int i = 0; i < sources.size(); ++i) { - auto src_parents = parents_map.find(sources[i].id()); - if (src_parents == parents_map.end()) { - continue; - } - auto& pairs = parents_map[dests[i].id()]; - for (auto& parent : src_parents->second) { - parent.first.inputs()[parent.second] = dests[i]; - pairs.push_back(parent); - } - // Remove the source from the map to avoid fusing with it again - parents_map.erase(src_parents); - } - }; - // Depth-1 array equivalence check. auto array_equivalent = [](const array& a, const array& b) { if (!a.has_primitive() || !b.has_primitive()) { @@ -254,33 +395,32 @@ void compile_simplify( return pa.is_equivalent(pb); }; - // Pass 0: fuse scalars + // Merge scalars std::vector new_tape; for (auto& arr : tape) { - // Check if we can fuse scalars + // Check if we can merge scalars if (is_scalar(arr)) { auto scalar = scalars.find(get_scalar_rep(arr)); if (scalar->second.id() != arr.id()) { - fuse(scalar->second, arr); + merge(scalar->second, arr, parents_map); // Don't keep orphaned scalars in the tape continue; } } new_tape.push_back(std::move(arr)); } - tape = std::move(new_tape); std::unordered_set output_set; for (auto& o : outputs) { output_set.insert(o.id()); } - // Pass 1..passes: fuse only keeping non-orphaned arrays in the tape + // Multi-pass merge only keeping non-orphaned arrays in the tape for (int pass = 0; pass < passes; ++pass) { for (auto& arr : tape) { - // Helper to check if we can fuse the parents of the + // Helper to check if we can merge the parents of the // given array - auto maybe_fuse_parents = [&](auto& a) { + auto maybe_merge_parents = [&](auto& a) { auto parents = parents_map.find(a.id()); if (parents != parents_map.end()) { auto N = parents->second.size(); @@ -296,7 +436,7 @@ void compile_simplify( auto& src = parents->second[j].first; auto& dst = parents->second[i].first; if (src.id() != dst.id() && array_equivalent(src, dst)) { - fuse(dst, src); + merge(dst, src, parents_map); mask[j] = true; } } @@ -313,9 +453,9 @@ void compile_simplify( } }; - bool discard = maybe_fuse_parents(arr); + bool discard = maybe_merge_parents(arr); for (auto& s : arr.siblings()) { - discard &= maybe_fuse_parents(s); + discard &= maybe_merge_parents(s); } // If an array and its siblings have no parents, and none of them are // outputs, it is safe to remove it from the tape @@ -327,6 +467,216 @@ void compile_simplify( } } +// Extract sub-graphs of the graph that can be compiled +// and replace them with a Compiled Primitive. +void compile_fuse( + std::vector& tape, + ParentsMap& parents_map, + const std::vector& inputs, + std::vector& outputs) { + // Track outputs to replace with new compiled outputs + std::unordered_map output_map; + for (auto& o : outputs) { + output_map.insert({o.id(), o}); + } + + // Set of inputs to distinguish constants + std::unordered_set input_ids; + for (auto& in : inputs) { + input_ids.insert(in.id()); + } + + // Go through the tape in reverse order and check for fusable sub-graphs + std::vector new_tape; + std::unordered_set global_cache; + for (int i = tape.size() - 1; i >= 0; --i) { + auto& arr = tape[i]; + + // Already compiled + if (global_cache.find(arr.id()) != global_cache.end()) { + continue; + } + + // Two pass recursion: + // First pass: + // - Collect all the primitives which we can fuse with + // - Keeps a cache of fusable primitives which may be added out of + // DAG order. We have to determine if all of a fused primitive's + // outputs are also in the fused section, and this may not be the + // case the first time we visit it. + // Second pass: + // - Collect inputs to the new compiled primitive + // - Add fusable primitives to a tape in the correct order + + std::function recurse; + std::unordered_set cache; + recurse = [&](const array& a, int depth, const Stream& s) { + if (cache.find(a.id()) != cache.end()) { + return; + } + + // Stop fusing if: + // - Depth limit exceeded + // - Constant input + // - Stream mismatch + // - Non fusable primitive + if (depth >= max_compile_depth || !a.has_primitive() || + a.primitive().stream() != s || !is_fusable(a.primitive())) { + return; + } + + bool all_parents_in = true; + if (depth > 0) { + // Guaranteed to have a parent since nested in the + // recursion. + auto& parents = parents_map.at(a.id()); + for (auto& [p, idx] : parents) { + auto in_cache = cache.find(p.id()) != cache.end(); + if (!in_cache) { + all_parents_in = false; + break; + } + } + } + + // Arrays with a mix of parents outside the compilable section + // are not fusable + if (!all_parents_in) { + return; + } + + cache.insert({a.id()}); + + for (auto& in : a.inputs()) { + recurse(in, depth + 1, s); + } + }; + + if (arr.has_primitive()) { + Stream s = arr.primitive().stream(); + recurse(arr, 0, s); + } + + // Not worth fusing a single primitive + if (cache.size() <= 1) { + new_tape.push_back(arr); + continue; + } + + // Recurse a second time to build the tape in the right + // order and collect the inputs + std::unordered_set input_set; + std::vector inputs; + std::vector fused_tape; + std::unordered_set tape_set; + std::function recurse_tape; + recurse_tape = [&](const array& a) { + if (cache.find(a.id()) == cache.end()) { + if (input_set.find(a.id()) == input_set.end()) { + input_set.insert(a.id()); + inputs.push_back(a); + } + return; + } + if (tape_set.find(a.id()) != tape_set.end()) { + return; + } + tape_set.insert(a.id()); + for (auto& in : a.inputs()) { + recurse_tape(in); + } + fused_tape.push_back(a); + }; + recurse_tape(arr); + + std::vector old_outputs; + // Add to global cache and add any global outputs to outputs + // of new primitive + for (int j = 0; j < fused_tape.size() - 1; ++j) { + auto& f = fused_tape[j]; + if (output_map.find(f.id()) != output_map.end()) { + old_outputs.push_back(f); + // Parents are now siblings, update the parent map + auto& pairs = parents_map[f.id()]; + pairs.erase( + std::remove_if( + pairs.begin(), + pairs.end(), + [&](auto& p) { + return cache.find(p.first.id()) != cache.end(); + }), + pairs.end()); + } else { + // Remove inner fused arrays parents from the parents map + // to keep the parents map in a valid state + parents_map.erase(f.id()); + } + global_cache.insert({f.id()}); + } + old_outputs.push_back(arr); + + std::vector> shapes; + std::vector types; + for (auto& o : old_outputs) { + shapes.push_back(o.shape()); + types.push_back(o.dtype()); + } + std::unordered_set constant_ids; + for (auto& in : inputs) { + // Scalar constant + if (in.size() == 1 && !in.has_primitive() && + input_ids.find(in.id()) == input_ids.end()) { + constant_ids.insert(in.id()); + } + } + auto compiled_outputs = array::make_arrays( + shapes, + types, + std::make_shared( + outputs.back().primitive().stream(), + inputs, + old_outputs, + std::move(fused_tape), + std::move(constant_ids)), + inputs); + + // One output per primitive + new_tape.push_back(compiled_outputs.back()); + + // Replace inputs old parents with compiled_outputs + for (int i = 0; i < inputs.size(); ++i) { + auto& pairs = parents_map[inputs[i].id()]; + pairs.erase( + std::remove_if( + pairs.begin(), + pairs.end(), + [&](auto& p) { return cache.find(p.first.id()) != cache.end(); }), + pairs.end()); + for (auto& o : compiled_outputs) { + pairs.push_back({o, i}); + } + } + + // - Update outputs parents to point to compiled outputs + // - Update any overall graph outputs to be compiled outputs + for (int o = 0; o < old_outputs.size(); ++o) { + merge(compiled_outputs[o], old_outputs[o], parents_map); + if (auto it = output_map.find(old_outputs[o].id()); + it != output_map.end()) { + it->second = compiled_outputs[o]; + } + } + } + + std::reverse(new_tape.begin(), new_tape.end()); + tape = std::move(new_tape); + + // Replace output with potentially compiled output + for (auto& o : outputs) { + o = output_map.at(o.id()); + } +} + std::vector compile_replace( const std::vector& tape, const std::vector& trace_inputs, @@ -380,7 +730,7 @@ std::vector compile_replace( std::function(const std::vector&)> compile( const std::function(const std::vector&)>& fun, size_t fun_id) { - if (compiler_disabled()) { + if (compile_mode() == CompileMode::disabled) { return fun; } return [fun, fun_id](const std::vector& inputs) { @@ -402,10 +752,16 @@ std::function(const std::vector&)> compile( compile_dfs(entry.inputs, entry.outputs); // Simplify the tape - compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 3); + if (compile_mode() != CompileMode::no_simplify) { + compile_simplify( + entry.tape, parents_map, entry.outputs, /* passes */ 3); + } - // This is a good point to do more optimizations, e.g. kernel fusion to - // generate new primitives. The tape needs to be updated accordingly + // Kernel fusion to generate Compiled primitives. The tape and + // new outputs must be updated accordingly + if (compile_mode() != CompileMode::no_fuse) { + compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs); + } } // At this point we must have a tape, now replace the placeholders @@ -422,7 +778,7 @@ void compile_erase(size_t fun_id) { std::function(const std::vector&)> compile( const std::function(const std::vector&)>& fun) { - if (detail::compiler_disabled()) { + if (detail::compile_mode() == CompileMode::disabled) { return fun; } auto fun_id = detail::getAddress(fun); @@ -430,11 +786,15 @@ std::function(const std::vector&)> compile( } void disable_compile() { - detail::compiler_disabled() = true; + detail::compile_mode() = CompileMode::disabled; } void enable_compile() { - detail::compiler_disabled() = false; + detail::compile_mode() = CompileMode::enabled; +} + +void set_compile_mode(CompileMode mode) { + detail::compile_mode() = mode; } } // namespace mlx::core diff --git a/mlx/compile.h b/mlx/compile.h new file mode 100644 index 000000000..fb3115d61 --- /dev/null +++ b/mlx/compile.h @@ -0,0 +1,28 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +enum class CompileMode { disabled, no_simplify, no_fuse, enabled }; + +// Compile takes a function and returns a new function +std::function(const std::vector&)> compile( + const std::function(const std::vector&)>& fun); + +/** Globally disable compilation. + * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also + * be used to disable compilation. + */ +void disable_compile(); + +/** Globally enable compilation. + * This will override the environment variable ``MLX_DISABLE_COMPILE``. + */ +void enable_compile(); + +/** Set the compiler mode to the given value. */ +void set_compile_mode(CompileMode mode); +} // namespace mlx::core diff --git a/mlx/mlx.h b/mlx/mlx.h index a0a281974..7b33faba7 100644 --- a/mlx/mlx.h +++ b/mlx/mlx.h @@ -4,6 +4,7 @@ #include "mlx/array.h" #include "mlx/backend/metal/metal.h" +#include "mlx/compile.h" #include "mlx/device.h" #include "mlx/fft.h" #include "mlx/io.h" diff --git a/mlx/primitives.h b/mlx/primitives.h index 15d6c485e..5bdee12cf 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2,6 +2,8 @@ #pragma once +#include + #include "mlx/array.h" #include "mlx/device.h" #include "mlx/io/load.h" @@ -451,6 +453,46 @@ class Ceil : public UnaryPrimitive { void eval(const std::vector& inputs, array& out); }; +class Compiled : public Primitive { + public: + /* + * The inputs, outputs and tape are either tracers or constants. + * - The tape should not contain the inputs, but it should contain the + * outputs. + * - The tape should also have only one array per primitive for multi-output + * primitives. + * - The constant_ids contains ids of arrays in the input list that are safe + * to treat as scalar constants. + */ + explicit Compiled( + Stream stream, + std::vector inputs, + std::vector outputs, + std::vector tape, + std::unordered_set constant_ids); + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_GRADS() + + void print(std::ostream& os) override; + + bool is_equivalent(const Primitive& other) const override; + + private: + const std::vector inputs_; + const std::vector outputs_; + const std::vector tape_; + const std::unordered_set constant_ids_; + + void eval(const std::vector& inputs, std::vector& out); +}; + class Concatenate : public UnaryPrimitive { public: explicit Concatenate(Stream stream, int axis) diff --git a/mlx/transforms.h b/mlx/transforms.h index 297571e1d..95fdfcffe 100644 --- a/mlx/transforms.h +++ b/mlx/transforms.h @@ -6,21 +6,6 @@ namespace mlx::core { -// Compile takes a function and returns a new function -std::function(const std::vector&)> compile( - const std::function(const std::vector&)>& fun); - -/** Globally disable compilation. - * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also - * be used to disable compilation. - */ -void disable_compile(); - -/** Globally enable compilation. - * This will override the environment variable ``MLX_DISABLE_COMPILE``. - */ -void enable_compile(); - void eval(const std::vector& outputs); template diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 8937892fb..78f867876 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -7,6 +7,7 @@ #include #include "mlx/array.h" +#include "mlx/compile.h" #include "mlx/graph_utils.h" #include "mlx/transforms.h" #include "mlx/transforms_impl.h" diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 594dd2859..56dff8b3d 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -190,6 +190,117 @@ class TestCompile(mlx_tests.MLXTestCase): n_enable_compiled = count_prims(cfun(x)) self.assertEqual(n_compiled, n_enable_compiled) + def test_compile_two_input_grad(self): + def loss(w, x): + y = x * w + return (y * mx.exp(y)).sum() + + x = mx.array([1.0, 0.5, 2.0, -0.5]) + w = mx.array([-1.0, 0.3, 1.0, -0.9]) + + expected_grad = mx.grad(loss)(w, x) + compiled_grad = mx.compile(mx.grad(loss))(w, x) + self.assertTrue(mx.allclose(expected_grad, compiled_grad)) + + def test_vmap_compiled(self): + def simple_unary(x): + return -mx.exp(x) + + x = mx.array([[1.0, 2.0], [2.0, 3.0]]) + + expected_out = mx.vmap(simple_unary)(x) + out = mx.vmap(mx.compile(simple_unary))(x) + self.assertTrue(mx.allclose(expected_out, out)) + + def simple_binary(x, y): + return mx.abs(mx.exp(x + y) + y) + + x = mx.array([[1.0, -3.0], [0.5, -0.5]]) + y = mx.array([[2.0, -1.0], [0.25, -0.25]]) + + expected_out = mx.vmap(simple_binary)(x, y) + out = mx.vmap(mx.compile(simple_binary))(x, y) + self.assertTrue(mx.allclose(expected_out, out)) + + expected_out = mx.vmap(simple_binary, in_axes=(0, 1))(x, y) + out = mx.vmap(mx.compile(simple_binary), in_axes=(0, 1))(x, y) + self.assertTrue(mx.allclose(expected_out, out)) + + y = mx.array([0.25, -0.25]) + expected_out = mx.vmap(simple_binary, in_axes=(0, None))(x, y) + out = mx.vmap(mx.compile(simple_binary), in_axes=(0, None))(x, y) + self.assertTrue(mx.allclose(expected_out, out)) + + def simple_unary_outer(x): + x = mx.abs(x) + + @mx.compile + def simple_unary_inner(z): + return -mx.exp(x) + + return simple_unary_inner(x) + + expected_out = -mx.exp(mx.abs(x)) + out = mx.vmap(simple_unary_outer)(x) + self.assertTrue(mx.allclose(expected_out, out)) + + def test_vjp_vjp_compiled(self): + def simple_unary(x): + return -mx.exp(x) + + x = mx.array([[1.0, 2.0], [2.0, 3.0]]) + y = mx.array([[1.0, 1.0], [1.0, 1.0]]) + + expected_out, expected_vjp_out = mx.vjp(simple_unary, (x,), (y,)) + out, vjp_out = mx.vjp(mx.compile(simple_unary), (x,), (y,)) + self.assertTrue(mx.allclose(expected_vjp_out[0], vjp_out[0])) + self.assertTrue(mx.allclose(expected_out[0], out[0])) + + expected_out, expected_jvp_out = mx.jvp(simple_unary, (x,), (y,)) + out, jvp_out = mx.jvp(mx.compile(simple_unary), (x,), (y,)) + self.assertTrue(mx.allclose(expected_jvp_out[0], jvp_out[0])) + self.assertTrue(mx.allclose(expected_out[0], out[0])) + + def simple_binary(x, y): + return mx.abs(mx.exp(x + y) + y) + + x = mx.array([[1.0, -3.0], [0.5, -0.5]]) + y = mx.array([[2.0, -1.0], [0.25, -0.25]]) + cotans = mx.ones_like(x) + + expected_out, expected_vjp_out = mx.vjp(simple_binary, (x, y), (cotans,)) + out, vjp_out = mx.vjp(mx.compile(simple_binary), (x, y), (cotans,)) + self.assertTrue(mx.allclose(expected_out[0], out[0])) + self.assertTrue(mx.allclose(expected_vjp_out[0], vjp_out[0])) + self.assertTrue(mx.allclose(expected_vjp_out[1], vjp_out[1])) + + tans = (mx.ones_like(x), mx.ones_like(y)) + expected_out, expected_jvp_out = mx.jvp(simple_binary, (x, y), tans) + out, jvp_out = mx.jvp(mx.compile(simple_binary), (x, y), tans) + self.assertTrue(mx.allclose(expected_jvp_out[0], jvp_out[0])) + self.assertTrue(mx.allclose(expected_out[0], out[0])) + + def test_transform_over_eval_compiled(self): + def outer(x): + y = mx.exp(mx.abs(x)) + mx.eval(y) + return y.sum() + + x = mx.array([2.0, -1.0, 0.5]) + dfdx = mx.grad(outer)(x) + + @mx.compile + def simple_unary(x): + return mx.exp(mx.abs(x)) + + def outer(x): + y = simple_unary(x) + mx.eval(y) + return y.sum() + + cdfdx = mx.grad(outer)(x) + self.assertTrue(mx.allclose(dfdx, cdfdx)) + if __name__ == "__main__": unittest.main() diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 3d66d6afd..0d96ae23e 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -3,6 +3,7 @@ #include "doctest/doctest.h" #include "mlx/mlx.h" +#include "mlx/primitives.h" using namespace mlx::core; @@ -120,6 +121,7 @@ auto max_scalars(const std::vector&) { }; TEST_CASE("test simplify scalars") { + set_compile_mode(CompileMode::no_fuse); { auto cfun = compile(add_scalars); auto out = cfun({}); @@ -136,6 +138,7 @@ TEST_CASE("test simplify scalars") { auto d = out[2]; CHECK(b.inputs()[1].id() == c.inputs()[1].id()); } + set_compile_mode(CompileMode::enabled); } auto exp_two(const std::vector& inputs) { @@ -144,9 +147,11 @@ auto exp_two(const std::vector& inputs) { }; TEST_CASE("test simplify") { + set_compile_mode(CompileMode::no_fuse); auto a = array({1.0f, 2.0f}); auto b = compile(exp_two)({a})[0]; CHECK(b.inputs()[0].id() == b.inputs()[1].id()); + set_compile_mode(CompileMode::enabled); } auto add_diff(const std::vector& inputs) { @@ -155,9 +160,11 @@ auto add_diff(const std::vector& inputs) { }; TEST_CASE("test no simplify") { + set_compile_mode(CompileMode::no_fuse); auto a = array({1.0f, 2.0f}); auto b = compile(add_diff)({a})[0]; CHECK(b.inputs()[0].id() != b.inputs()[1].id()); + set_compile_mode(CompileMode::enabled); } auto multi_one(const std::vector&) { @@ -187,6 +194,7 @@ auto multi_three(const std::vector&) { } TEST_CASE("test simplify multi output") { + set_compile_mode(CompileMode::no_fuse); { auto out = compile(multi_one)({}); auto e = out[0]; @@ -210,4 +218,372 @@ TEST_CASE("test simplify multi output") { CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id()); CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id()); } + set_compile_mode(CompileMode::enabled); +} + +// No fusion +auto unary_fused_0(const std::vector& inputs) { + return std::vector{exp(inputs[0])}; +} + +// All compilable +auto unary_fused_1(const std::vector& inputs) { + return std::vector{abs(negative(exp(inputs[0])))}; +} + +auto unary_fused_1_copy(const std::vector& inputs) { + return std::vector{abs(negative(exp(inputs[0])))}; +} + +auto unary_fused_1_diff(const std::vector& inputs) { + return std::vector{abs(exp(negative(inputs[0])))}; +} + +// Output into un-compilable primitive +auto unary_fused_2(const std::vector& inputs) { + return std::vector{sum(abs(negative(exp(inputs[0]))), true)}; +} + +// Input from un-compilable primitive +auto unary_fused_3(const std::vector& inputs) { + return std::vector{exp(abs(negative(sum(inputs[0], true))))}; +} + +TEST_CASE("test compile unary fused") { + // NB: some of these tests are brittle and may need to be + // updated if we change compile conditions + { + auto cfun = compile(unary_fused_0); + auto x = array(2.0); + auto out = cfun({x})[0]; + + auto& p = out.primitive(); + CHECK_EQ(typeid(p), typeid(Exp)); + CHECK_EQ(out.inputs()[0].id(), x.id()); + } + + { + auto cfun = compile(unary_fused_1); + auto x = array(2.0); + auto out = cfun({x})[0]; + + auto& p = out.primitive(); + CHECK_EQ(typeid(p), typeid(Compiled)); + CHECK_EQ(out.inputs()[0].id(), x.id()); + + auto expected_out = unary_fused_1({array(2.0)})[0]; + CHECK_EQ(out.item(), expected_out.item()); + } + + { + auto cfun = compile(unary_fused_2); + auto x = array({1.0, 2.0}); + auto out = cfun({x}); + CHECK_EQ(out.size(), 1); + + auto& p = out[0].primitive(); + // NB: this test is brittle, will need to update + // it if we change compile conditions + CHECK_EQ(typeid(p), typeid(Reduce)); + auto cout = out[0].inputs()[0]; + auto& cp = cout.primitive(); + CHECK_EQ(typeid(cp), typeid(Compiled)); + CHECK_EQ(cout.inputs()[0].id(), x.id()); + } + + { + auto cfun = compile(unary_fused_3); + auto x = array({1.0, 2.0}); + auto out = cfun({x}); + + auto& p = out[0].primitive(); + CHECK_EQ(typeid(p), typeid(Compiled)); + auto sout = out[0].inputs()[0]; + CHECK_EQ(out[0].inputs().size(), 1); + auto& sp = sout.primitive(); + CHECK_EQ(typeid(sp), typeid(Reduce)); + CHECK_EQ(sout.inputs()[0].id(), x.id()); + } + + // Is equivalent works + { + auto out1 = compile(unary_fused_1)({array(1.0)}); + auto out2 = compile(unary_fused_1_copy)({array(1.0)}); + CHECK(out1[0].primitive().is_equivalent(out2[0].primitive())); + auto out3 = compile(unary_fused_1_diff)({array(1.0)}); + CHECK(!out1[0].primitive().is_equivalent(out3[0].primitive())); + } +} + +// All compilable +auto binary_fused_0(const std::vector& inputs) { + return std::vector{inputs[0] + inputs[1]}; +} + +// Binary into unary +auto binary_fused_1(const std::vector& inputs) { + return std::vector{abs(inputs[0] + inputs[1])}; +} + +// Binary into binary +auto binary_fused_2(const std::vector& inputs) { + auto x = inputs[0] + inputs[1]; + return std::vector{x + inputs[0]}; +} + +// Binary into unary into un-compilable +auto binary_fused_3(const std::vector& inputs) { + return std::vector{sum(abs(inputs[0] + inputs[1]), true)}; +} + +TEST_CASE("test compile binary fused") { + { + auto cfun = compile(binary_fused_0); + auto x = array(2.0); + auto y = array(2.0); + auto out = cfun({x, y})[0]; + + auto& p = out.primitive(); + CHECK_EQ(typeid(p), typeid(Add)); + CHECK_EQ(out.inputs()[0].id(), x.id()); + } + + { + auto cfun = compile(binary_fused_1); + auto x = array(2.0); + auto y = array(2.0); + auto out = cfun({x, y})[0]; + + auto& p = out.primitive(); + CHECK_EQ(typeid(p), typeid(Compiled)); + CHECK_EQ(out.inputs()[0].id(), x.id()); + CHECK_EQ(out.inputs()[1].id(), y.id()); + + auto expected_out = binary_fused_1({x, y})[0]; + CHECK_EQ(out.item(), expected_out.item()); + } + + { + auto cfun = compile(binary_fused_2); + auto x = array(2.0); + auto y = array(2.0); + auto out = cfun({x, y})[0]; + + auto& p = out.primitive(); + CHECK_EQ(typeid(p), typeid(Compiled)); + CHECK_EQ(out.inputs()[0].id(), x.id()); + CHECK_EQ(out.inputs()[1].id(), y.id()); + } + + { + auto cfun = compile(binary_fused_3); + auto x = array({1.0, 2.0}); + auto y = array({1.0, 2.0}); + auto out = cfun({x, y})[0]; + + auto& p = out.primitive(); + CHECK_EQ(typeid(p), typeid(Reduce)); + + auto cout = out.inputs()[0]; + auto& cp = cout.primitive(); + CHECK_EQ(typeid(cp), typeid(Compiled)); + CHECK_EQ(cout.inputs()[0].id(), x.id()); + CHECK_EQ(cout.inputs()[1].id(), y.id()); + } +} + +auto gelu_1(const std::vector& inputs) { + auto& x = inputs[0]; + auto out = x * (1.0f + erf(x / M_SQRT2)) / 2.0f; + return std::vector{out}; +} + +TEST_CASE("test compile gelu") { + { + auto cfun = compile(gelu_1); + auto x = array(1.0); + auto out = cfun({x})[0]; + auto& p = out.primitive(); + CHECK_EQ(typeid(p), typeid(Compiled)); + CHECK_EQ(out.inputs().size(), 4); + for (auto& in : out.inputs()) { + CHECK(in.inputs().empty()); + } + auto expected_out = gelu_1({x})[0]; + CHECK(allclose(out, expected_out).item()); + } + + { + auto cfun = compile(gelu_1); + auto x = array({1.0, 0.5}); + auto out = cfun({x})[0]; + auto& p = out.primitive(); + CHECK_EQ(typeid(p), typeid(Compiled)); + CHECK_EQ(out.inputs().size(), 4); + for (auto& in : out.inputs()) { + CHECK(in.inputs().empty()); + } + + auto expected_out = gelu_1({x})[0]; + CHECK(allclose(out, expected_out).item()); + } +} + +// Uncompilable input outside fused tape +auto unary_with_two_outputs(const std::vector& inputs) { + auto x = exp(inputs[0]); + return std::vector{exp(x), sum(x, true)}; +} + +auto uncompilable_inputs(const std::vector& inputs) { + auto& x = inputs[0]; + auto& y = inputs[1]; + return std::vector{x * abs(exp(y)), sum(x, true)}; +} + +auto uncompilable_inputs_order_matters(const std::vector& inputs) { + auto& x = inputs[0]; + auto& y = inputs[1]; + return std::vector{x / abs(exp(y)), sum(x, true)}; +} + +TEST_CASE("test compile tape with outside parents") { + { + auto cfun = compile(unary_with_two_outputs); + auto x = array({2.0, 2.0}); + auto out = cfun({x}); + + auto& p1 = out[0].primitive(); + CHECK_EQ(typeid(p1), typeid(Exp)); + auto& p2 = out[1].primitive(); + CHECK_EQ(typeid(p2), typeid(Reduce)); + } + + { + auto cfun = compile(uncompilable_inputs); + auto x = array({2.0, 2.0}); + auto y = array({1.6, 0.6}); + auto outs = cfun({x, y}); + + auto& p1 = outs[0].primitive(); + CHECK_EQ(typeid(p1), typeid(Compiled)); + auto& p2 = outs[1].primitive(); + CHECK_EQ(typeid(p2), typeid(Reduce)); + CHECK_EQ(outs[0].inputs().size(), 2); + + auto expected_outs = uncompilable_inputs({x, y}); + CHECK(allclose(outs[0], expected_outs[0]).item()); + CHECK(allclose(outs[1], expected_outs[1]).item()); + } + + { + auto cfun = compile(uncompilable_inputs_order_matters); + auto x = array({2.0, 2.0}); + auto y = array({1.6, 0.6}); + auto outs = cfun({x, y}); + + auto& p1 = outs[0].primitive(); + CHECK_EQ(typeid(p1), typeid(Compiled)); + auto& p2 = outs[1].primitive(); + CHECK_EQ(typeid(p2), typeid(Reduce)); + CHECK_EQ(outs[0].inputs().size(), 2); + + auto expected_outs = uncompilable_inputs_order_matters({x, y}); + CHECK(allclose(outs[0], expected_outs[0]).item()); + CHECK(allclose(outs[1], expected_outs[1]).item()); + } +} + +auto compile_accross_streams(const std::vector& inputs) { + auto s2 = new_stream(default_device()); + auto x = exp(abs(inputs[0])); + auto y = exp(abs(x, s2), s2); + return std::vector{y}; +} + +TEST_CASE("test compile accross streams") { + auto cfun = compile(compile_accross_streams); + auto x = array({2.0f}); + auto out = cfun({x})[0]; + auto& p1 = out.primitive(); + CHECK_EQ(typeid(p1), typeid(Compiled)); + CHECK_EQ(out.inputs().size(), 1); + auto child = out.inputs()[0]; + auto& p2 = child.primitive(); + CHECK_EQ(typeid(p2), typeid(Compiled)); + CHECK_EQ(child.inputs()[0].id(), x.id()); +} + +auto unary_compile_outputs(const std::vector& inputs) { + auto x = abs(inputs[0]); + auto y = square(x); + return std::vector{x, y}; +} + +auto binary_compile_outputs(const std::vector& inputs) { + auto x = inputs[0]; + auto y = inputs[1]; + x = x + y; + y = x + y; + return std::vector{x, y}; +} + +TEST_CASE("test compile internal output") { + { + auto cfun = compile(unary_compile_outputs); + auto x = array({3, -2}); + auto outs = cfun({x}); + auto& p1 = outs[0].primitive(); + CHECK_EQ(typeid(p1), typeid(Compiled)); + auto& p2 = outs[1].primitive(); + CHECK_EQ(typeid(p2), typeid(Compiled)); + CHECK_EQ(outs[0].siblings()[0].id(), outs[1].id()); + auto expected_outs = unary_compile_outputs({x}); + CHECK(array_equal(outs[0], expected_outs[0]).item()); + CHECK(array_equal(outs[1], expected_outs[1]).item()); + } + + { + auto cfun = compile(binary_compile_outputs); + auto x = array({3, -2}); + auto y = array({1, -1}); + auto outs = cfun({x, y}); + auto& p1 = outs[0].primitive(); + CHECK_EQ(typeid(p1), typeid(Compiled)); + auto& p2 = outs[1].primitive(); + CHECK_EQ(typeid(p2), typeid(Compiled)); + auto expected_outs = binary_compile_outputs({x, y}); + CHECK(array_equal(outs[0], expected_outs[0]).item()); + CHECK(array_equal(outs[1], expected_outs[1]).item()); + } +} + +auto deep_unary_compile(const std::vector& inputs) { + auto x = inputs[0]; + for (int i = 0; i < 10; ++i) { + x = cos(sin(x)); + } + return std::vector{x}; +} + +TEST_CASE("test compile deep graph") { + auto cfun = compile(deep_unary_compile); + auto x = array({3.0f, -2.0f}); + auto out = cfun({x})[0]; + auto expected_out = deep_unary_compile({x})[0]; + CHECK(allclose(out, expected_out).item()); +} + +auto repeat_input_to_compiled(const std::vector& inputs) { + auto x = abs(exp(inputs[0])); + auto y = abs(exp(sum(x))); + return std::vector{x + y}; +} + +TEST_CASE("test compile repeat input") { + auto cfun = compile(repeat_input_to_compiled); + auto x = array({3.0f, -2.0f}); + auto out = cfun({x})[0]; + auto expected_out = repeat_input_to_compiled({x})[0]; + CHECK(allclose(out, expected_out).item()); }