diff --git a/mlx/array.cpp b/mlx/array.cpp index ad5d4f122..f0e92c4e2 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -47,6 +47,17 @@ array::array( std::move(primitive), inputs)) {} +array::array( + std::vector shape, + Dtype dtype, + std::shared_ptr primitive, + std::vector&& inputs) + : array_desc_(std::make_shared( + std::move(shape), + dtype, + std::move(primitive), + std::move(inputs))) {} + std::vector array::make_arrays( const std::vector>& shapes, const std::vector& dtypes, @@ -164,6 +175,21 @@ array::ArrayDesc::ArrayDesc( } } +array::ArrayDesc::ArrayDesc( + std::vector&& shape, + Dtype dtype, + std::shared_ptr primitive, + std::vector&& inputs) + : shape(std::move(shape)), + dtype(dtype), + primitive(std::move(primitive)), + inputs(std::move(inputs)) { + std::tie(size, strides) = cum_prod(shape); + for (auto& in : inputs) { + is_tracer |= in.is_tracer(); + } +} + array::ArrayIterator::ArrayIterator(const array& arr, int idx) : arr(arr), idx(idx) { if (arr.ndim() == 0) { diff --git a/mlx/array.h b/mlx/array.h index 0d530162c..8f345bd3f 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -172,6 +172,12 @@ class array { std::shared_ptr primitive, const std::vector& inputs); + array( + std::vector shape, + Dtype dtype, + std::shared_ptr primitive, + std::vector&& inputs); + static std::vector make_arrays( const std::vector>& shapes, const std::vector& dtypes, @@ -215,6 +221,11 @@ class array { return *(array_desc_->primitive); }; + /** A shared pointer to the array's primitive. */ + std::shared_ptr& primitive_ptr() const { + return array_desc_->primitive; + }; + /** Check if the array has an attached primitive or is a leaf node. */ bool has_primitive() const { return array_desc_->primitive != nullptr; @@ -360,6 +371,12 @@ class array { Dtype dtype, std::shared_ptr primitive, const std::vector& inputs); + + explicit ArrayDesc( + std::vector&& shape, + Dtype dtype, + std::shared_ptr primitive, + std::vector&& inputs); }; // The ArrayDesc contains the details of the materialized array including the diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 48ce8b806..223ba1673 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -1,12 +1,15 @@ // Copyright © 2023 Apple Inc. #include // TODO + +#include +#include + #include "mlx/transforms.h" +#include "mlx/transforms_impl.h" namespace mlx::core { -// class CompilerCache { -// std::unordered_map -// } +using CompileFn = std::function(const std::vector&)>; template size_t getAddress(std::function f) { @@ -15,32 +18,194 @@ size_t getAddress(std::function f) { return (size_t)*fnPointer; } -int g(int, int) { - return 2; +struct CompilerCache { + struct CacheEntry { + std::vector inputs; + std::vector outputs; + std::vector tape; + bool empty{true}; + }; + + // Returns a reference to a CacheEntry which can be updated + // by the caller to avoid copying large tapes / inputs / outputs + CacheEntry& find(const CompileFn& fn, const std::vector& inputs) { + // Try to find the entry + auto inserted = cache_.insert({getAddress(fn), {}}); + auto& entries = inserted.first->second; + auto is_match = [](const std::vector& in1, + const std::vector& in2) { + if (in1.size() != in2.size()) { + throw std::runtime_error( + "[compiler] Got different number of inputs to function," + " this should never happen."); + } + for (int i = 0; i < in1.size(); ++i) { + if (in1[i].shape() != in2[i].shape()) { + return false; + } + if (in1[i].dtype() != in2[i].dtype()) { + return false; + } + } + return true; + }; + + // Loop over entries and check inputs match i.e. shapes and types must be + // equal. Note this could get really slow if one compiles the same + // function with many different shapes. May want to store entries in a + // more easily searchable structure. + for (auto& entry : entries) { + // Check the inputs match and return if so + if (is_match(inputs, entry.inputs)) { + return entry; + } + } + // Otherwise append a new cache entry + entries.push_back(CacheEntry{}); + return entries.back(); + }; + + private: + CompilerCache() {} + friend CompilerCache& compiler_cache(); + std::unordered_map> cache_; +}; + +CompilerCache& compiler_cache() { + static CompilerCache compiler_cache_; + return compiler_cache_; +} + +std::pair, std::vector> compile_trace( + const std::function(const std::vector&)>& fun, + const std::vector& inputs) { + // Set the global tracing flag. + detail::InTracing in_tracing; + + // Run the function on placeholder inputs + // to get compute graph + std::vector tracer_inputs; + for (int i = 0; i < inputs.size(); ++i) { + array in(inputs[i].shape(), inputs[i].dtype(), nullptr, {}); + in.set_tracer(true); + tracer_inputs.push_back(std::move(in)); + } + return {tracer_inputs, fun(tracer_inputs)}; +} + +std::vector compile_dfs_graph( + const std::vector& inputs, + const std::vector& outputs) { + std::unordered_set needs_compile; + std::unordered_set cache; + for (int i = 0; i < inputs.size(); ++i) { + auto in = inputs[i]; + needs_compile.insert(in.id()); + cache.insert(in.id()); + } + + // Topologically sort the graph + std::vector tape; + + std::function recurse; + + recurse = [&](const array& a) { + auto id = a.id(); + if (cache.find(id) != cache.end()) { + return; + } + cache.insert(id); + for (auto& s : a.siblings()) { + cache.insert(s.id()); + } + + // Recurse on inputs + for (auto& input : a.inputs()) { + recurse(input); + } + // If any input needs a vmap, then the outputs also need + // a vmap + for (auto& input : a.inputs()) { + if (needs_compile.find(input.id()) != needs_compile.end()) { + tape.push_back(a); + needs_compile.insert(a.id()); + for (auto s : a.siblings()) { + needs_compile.insert(s.id()); + } + break; + } + } + }; + + for (auto& out : outputs) { + if (out.has_primitive()) { + recurse(out); + } + } + return tape; +} + +std::vector compile_tape_replace( + const std::vector& tape, + const std::vector& trace_inputs, + const std::vector& trace_outputs, + const std::vector& inputs) { + std::unordered_map trace_to_real; + for (int i = 0; i < inputs.size(); ++i) { + trace_to_real.insert({trace_inputs[i].id(), inputs[i]}); + } + + // We need a map here of traced inputs to real inputs + for (auto& a : tape) { + if (!a.has_primitive()) { + std::runtime_error( + "[compile] Something went wrong, no primitive in tape"); + } + std::vector real_inputs; + for (auto& in : a.inputs()) { + real_inputs.push_back(trace_to_real.at(in.id())); + } + auto real_a = + array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)); + trace_to_real.insert({a.id(), std::move(real_a)}); + } + std::vector outputs; + for (auto& o : trace_outputs) { + outputs.push_back(trace_to_real.at(o.id())); + } + return outputs; } std::function(const std::vector&)> compile( const std::function(const std::vector&)>& fun) { - // Not doing too much at the moment // std::cout << getAddress(fun) << std::endl; return [&fun](const std::vector& inputs) { - std::cout << getAddress(fun) << std::endl; - // getAddress(std::function(g)); - // - // std::cout << getAddress(fun) << std::endl; - // Step 1 check the cache for the function. - // If it's in the cache check the shapes and types - // If they match then run the cached function, - // - // What exactly is the cached function? - // The return has to be the outputs of fun(inputs) which point to the - // correct inputs So we need to store a tape of primitives -> inputs (shape, - // dtype), outputs (shape, dtype) We need a level of indirection id to input - // to store the inputs so we can - // T - // Because eval will just want some pointers to arrays - // So you go through and set the - return fun(inputs); + // Find a cache entry with the correct inputs + auto& entry = compiler_cache().find(fun, inputs); + + // No matching cache entry existed, so compile + if (entry.empty) { + std::cout << "RECOMPILING? " << std::endl; + // Mark the entry as not empty since we are about to fill it + entry.empty = false; + + // Trace te build the graph + std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs); + + // This is a good point to do optimizations: + // - simplify + // - kernel fusion to generate new primitives + // - may make sense to keep the tape from simplify + // and pass it around so that we don't have to keep rebuilding it + + // Recurse to build the tape + entry.tape = compile_dfs_graph(entry.inputs, entry.outputs); + } + + // At this point we must have a tape, now replace the placeholders + // with real arrays that can be evaluated + return compile_tape_replace( + entry.tape, entry.inputs, entry.outputs, inputs); }; } diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index e83f5179e..d2f825235 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -14,4 +14,7 @@ TEST_CASE("test simple compile") { auto compfn = compile(simple_fun); auto out = compfn({array(1.0), array(2.0)})[0]; CHECK_EQ(out.item(), 3.0f); + + out = compfn({array(1.0), array(2.0)})[0]; + CHECK_EQ(out.item(), 3.0f); }