From cd1e5b25ccb95bb646d2f1449c3d848faf256a26 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 14 Jan 2024 14:26:53 -0800 Subject: [PATCH] compile binding --- mlx/array.cpp | 4 +- mlx/compile.cpp | 246 +++++++++++++++++++++++++++++------ mlx/transforms_impl.h | 6 + python/src/transforms.cpp | 47 +++++++ python/tests/test_compile.py | 22 ++++ 5 files changed, 285 insertions(+), 40 deletions(-) create mode 100644 python/tests/test_compile.py diff --git a/mlx/array.cpp b/mlx/array.cpp index f0e92c4e2..e19ed704d 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -169,7 +169,7 @@ array::ArrayDesc::ArrayDesc( dtype(dtype), primitive(std::move(primitive)), inputs(inputs) { - std::tie(size, strides) = cum_prod(shape); + std::tie(size, strides) = cum_prod(this->shape); for (auto& in : inputs) { is_tracer |= in.is_tracer(); } @@ -184,7 +184,7 @@ array::ArrayDesc::ArrayDesc( dtype(dtype), primitive(std::move(primitive)), inputs(std::move(inputs)) { - std::tie(size, strides) = cum_prod(shape); + std::tie(size, strides) = cum_prod(this->shape); for (auto& in : inputs) { is_tracer |= in.is_tracer(); } diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 223ba1673..91660effd 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -1,15 +1,21 @@ // Copyright © 2023 Apple Inc. #include // TODO +#include #include #include +#include "mlx/primitives.h" #include "mlx/transforms.h" #include "mlx/transforms_impl.h" namespace mlx::core { +namespace detail { + using CompileFn = std::function(const std::vector&)>; +using ParentsMap = + std::unordered_map>>; template size_t getAddress(std::function f) { @@ -28,9 +34,9 @@ struct CompilerCache { // Returns a reference to a CacheEntry which can be updated // by the caller to avoid copying large tapes / inputs / outputs - CacheEntry& find(const CompileFn& fn, const std::vector& inputs) { + CacheEntry& find(size_t fun_id, const std::vector& inputs) { // Try to find the entry - auto inserted = cache_.insert({getAddress(fn), {}}); + auto inserted = cache_.insert({fun_id, {}}); auto& entries = inserted.first->second; auto is_match = [](const std::vector& in1, const std::vector& in2) { @@ -93,38 +99,40 @@ std::pair, std::vector> compile_trace( return {tracer_inputs, fun(tracer_inputs)}; } -std::vector compile_dfs_graph( +// Traverses the graph to build a tape and a map of array ids to their parents +std::pair, ParentsMap> compile_dfs( const std::vector& inputs, const std::vector& outputs) { - std::unordered_set needs_compile; + std::function recurse; + std::vector tape; std::unordered_set cache; + std::unordered_map>> + parents_map; + std::unordered_set needs_compile; for (int i = 0; i < inputs.size(); ++i) { auto in = inputs[i]; needs_compile.insert(in.id()); cache.insert(in.id()); } - // Topologically sort the graph - std::vector tape; - - std::function recurse; - + // DFS the graph to build the tape, and log parents and scalars recurse = [&](const array& a) { auto id = a.id(); if (cache.find(id) != cache.end()) { return; } + for (int i = 0; i < a.inputs().size(); i++) { + auto& in = a.inputs()[i]; + parents_map[in.id()].push_back({a, i}); + for (auto& s : a.siblings()) { + parents_map[in.id()].push_back({s, i}); + } + recurse(in); + } cache.insert(id); for (auto& s : a.siblings()) { cache.insert(s.id()); } - - // Recurse on inputs - for (auto& input : a.inputs()) { - recurse(input); - } - // If any input needs a vmap, then the outputs also need - // a vmap for (auto& input : a.inputs()) { if (needs_compile.find(input.id()) != needs_compile.end()) { tape.push_back(a); @@ -136,16 +144,165 @@ std::vector compile_dfs_graph( } } }; - - for (auto& out : outputs) { - if (out.has_primitive()) { - recurse(out); - } + for (auto& a : outputs) { + recurse(a); } - return tape; + return {tape, parents_map}; } -std::vector compile_tape_replace( +// Simplify the tape. Note, this function modifies in-place both the tape and +// the parents map to remove orphaned arrays +void compile_simplify( + std::vector& tape, + ParentsMap& parents_map, + const std::vector& outputs, + int passes) { + // Helpers to identify identical scalars + std::map, array> scalars; + auto is_scalar = [](const array& a) { + return a.is_evaled() && a.ndim() == 0; + }; + auto get_scalar_rep = [](const array& a) { + uint64_t v = 0; + int dtype; + switch (a.dtype().size) { + case 1: + v = *a.data(); + break; + case 4: + v = *a.data(); + break; + case 8: + v = *a.data(); + break; + } + return std::make_pair(v, a.dtype().val); + }; + + for (auto& a : tape) { + if (is_scalar(a)) { + scalars.insert({get_scalar_rep(a), a}); + } + } + + // Helper that fuses two arrays in the graph by setting the parents of the + // source to point to the destination + auto fuse = [&](array& dst, array& src) { + // Canonicalize the order of the primitives outputs + auto sources = src.outputs(); + auto dests = dst.outputs(); + // For each src parent, point it to the corresponding dest + for (int i = 0; i < sources.size(); ++i) { + auto src_parents = parents_map.find(sources[i].id()); + if (src_parents == parents_map.end()) { + continue; + } + auto& pairs = parents_map[dests[i].id()]; + for (auto& parent : src_parents->second) { + parent.first.inputs()[parent.second] = dests[i]; + pairs.push_back(parent); + } + // Remove the source from the map to avoid fusing with it again + parents_map.erase(src_parents); + } + }; + + // Depth-1 array equivalence check. + auto array_equivalent = [](const array& a, const array& b) { + if (!a.has_primitive() || !b.has_primitive()) { + return false; + } + if (a.primitive_id() == b.primitive_id()) { + return false; + } + const auto& pa = a.primitive(); + const auto& pb = b.primitive(); + if (typeid(pa) != typeid(pb)) { + return false; + } + + if (a.inputs().size() != b.inputs().size()) { + return false; + } + + for (int i = 0; i < a.inputs().size(); i++) { + if (a.inputs()[i].id() != b.inputs()[i].id()) { + return false; + } + } + + return pa.is_equivalent(pb); + }; + + // Pass 0: fuse scalars + std::vector new_tape; + for (auto& arr : tape) { + // Check if we can fuse scalars + if (is_scalar(arr)) { + auto scalar = scalars.find(get_scalar_rep(arr)); + if (scalar->second.id() != arr.id()) { + fuse(scalar->second, arr); + // Don't keep orphaned scalars in the tape + continue; + } + } + new_tape.push_back(std::move(arr)); + } + + tape = std::move(new_tape); + + std::unordered_set output_set; + for (auto& o : outputs) { + output_set.insert(o.id()); + } + // Pass 1 to passes: fuse only keeping non-orphaned arrays in the tape + for (int pass = 0; pass < passes; ++pass) { + for (auto& arr : tape) { + // Helper to check if we can fuse the parents of the + // given array + // If an array has no parents and siblings have + auto maybe_fuse_parents = [&](auto& a) { + auto parents = parents_map.find(a.id()); + if (parents != parents_map.end()) { + auto N = parents->second.size(); + std::vector mask(N, false); + for (int i = 0; i < N; i++) { + if (mask[i]) { + continue; + } + for (int j = i + 1; j < N; j++) { + if (mask[j]) { + continue; + } + auto& src = parents->second[j].first; + auto& dst = parents->second[i].first; + if (src.id() != dst.id() && array_equivalent(src, dst)) { + fuse(dst, src); + mask[j] = true; + } + } + } + return false; + } else { + return output_set.find(a.id()) != output_set.end(); + } + }; + + bool discard = maybe_fuse_parents(arr); + for (auto& s : arr.siblings()) { + discard &= maybe_fuse_parents(s); + } + // If an array and its siblings have no parents, and none of them are + // outputs, it is safe to remove it from the tape + if (!discard) { + new_tape.push_back(std::move(arr)); + } + } + tape = std::move(new_tape); + } +} + +std::vector compile_replace( const std::vector& tape, const std::vector& trace_inputs, const std::vector& trace_outputs, @@ -155,7 +312,6 @@ std::vector compile_tape_replace( trace_to_real.insert({trace_inputs[i].id(), inputs[i]}); } - // We need a map here of traced inputs to real inputs for (auto& a : tape) { if (!a.has_primitive()) { std::runtime_error( @@ -177,36 +333,50 @@ std::vector compile_tape_replace( } std::function(const std::vector&)> compile( - const std::function(const std::vector&)>& fun) { - // std::cout << getAddress(fun) << std::endl; - return [&fun](const std::vector& inputs) { + const std::function(const std::vector&)>& fun, + size_t fun_id) { + return [&fun, fun_id](const std::vector& inputs) { // Find a cache entry with the correct inputs - auto& entry = compiler_cache().find(fun, inputs); + auto& entry = compiler_cache().find(fun_id, inputs); // No matching cache entry existed, so compile if (entry.empty) { - std::cout << "RECOMPILING? " << std::endl; // Mark the entry as not empty since we are about to fill it entry.empty = false; // Trace te build the graph std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs); - // This is a good point to do optimizations: - // - simplify - // - kernel fusion to generate new primitives - // - may make sense to keep the tape from simplify - // and pass it around so that we don't have to keep rebuilding it + // DFS the graph and get a tape, and a map of array id to (parent, + // position in parent inputs) + std::unordered_map>> + parents_map; + std::tie(entry.tape, parents_map) = + compile_dfs(entry.inputs, entry.outputs); - // Recurse to build the tape - entry.tape = compile_dfs_graph(entry.inputs, entry.outputs); + // Simplify the tape + // compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ + // 2); + + // This is a good point to do more optimizations, e.g. kernel fusion to + // generate new primitives. The tape needs to be updated accordingly } // At this point we must have a tape, now replace the placeholders // with real arrays that can be evaluated - return compile_tape_replace( - entry.tape, entry.inputs, entry.outputs, inputs); + return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs); }; } +} // namespace detail + +std::function(const std::vector&)> compile( + const std::function(const std::vector&)>& fun) { + auto fun_id = detail::getAddress(fun); + if (fun_id == 0) { + throw std::invalid_argument( + "[compile] Cannot compile a non-addressable function."); + } + return detail::compile(fun, fun_id); +} } // namespace mlx::core diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index 4b464bafd..5b9c7ada5 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -14,6 +14,12 @@ std::vector vmap_replace( const std::vector& in_axes, const std::vector& out_axes); +// This is not part of the general C++ API as calling with a bad id is a bad +// idea. +std::function(const std::vector&)> compile( + const std::function(const std::vector&)>& fun, + size_t fun_id); + // Create an InTracing object during tracing operations to signify to the rest // of the codebase that we are during tracing so evals should not throw away // the graph. diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 0d217a9f6..043dd8ea5 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -1,4 +1,5 @@ // Copyright © 2023 Apple Inc. +#include // TODO #include #include @@ -437,6 +438,34 @@ auto py_vmap( }; } +auto py_compile(const py::function& fun) { + return [fun](const py::args& args) { + // Inputs must be array or tree of arrays + auto inputs = tree_flatten(args, true); + + // py_value_out will hold the output of the python function in order to be + // able to reconstruct the python tree of extra return values + py::object py_outputs; + + auto compile_fun = + [&fun, &args, &inputs, &py_outputs](const std::vector& a) { + // Call the python function + py_outputs = fun(*tree_unflatten(args, a)); + + // Flatten the outputs + return tree_flatten(py_outputs, true); + }; + + // Compile and call + // TODO, awni, I think this cast is ok?? + size_t fun_id = reinterpret_cast(fun.ptr()); + auto outputs = detail::compile(compile_fun, fun_id)(inputs); + + // Put the outputs back in the container + return tree_unflatten(py_outputs, outputs); + }; +} + void init_transforms(py::module_& m) { py::options options; options.disable_function_signatures(); @@ -736,4 +765,22 @@ void init_transforms(py::module_& m) { } }, "file"_a); + m.def( + "compile", + [](const py::function& fun) { return py::cpp_function(py_compile(fun)); }, + "fun"_a, + R"pbdoc( + compile(fun: function) -> function + + Returns a compiled function which produces the same output as ``fun``. + + Args: + fun (function): A function which takes a variable number of + :class:`array` or trees of :class:`array` and returns + a variable number of :class:`array` or trees of :class:`array`. + + Returns: + function: A compiled function which has the same input arguments + as ``fun`` and returns the the same output(s). + )pbdoc"); } diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py new file mode 100644 index 000000000..39e075d72 --- /dev/null +++ b/python/tests/test_compile.py @@ -0,0 +1,22 @@ +# Copyright © 2023-2024 Apple Inc. + +import unittest + +import mlx.core as mx +import mlx_tests + + +class TestCompile(mlx_tests.MLXTestCase): + def test_simple_compile(self): + def fun(x, y): + return x + y + + compiled_fn = mx.compile(fun) + compiled_fn = mx.compile(fun) + x = mx.array(1.0) + y = mx.array(1.0) + # out = compiled_fn(x, y) + + +if __name__ == "__main__": + unittest.main()