diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index bb546616c..cb65e518b 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -5,6 +5,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp diff --git a/mlx/array.cpp b/mlx/array.cpp index ad5d4f122..e19ed704d 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -47,6 +47,17 @@ array::array( std::move(primitive), inputs)) {} +array::array( + std::vector shape, + Dtype dtype, + std::shared_ptr primitive, + std::vector&& inputs) + : array_desc_(std::make_shared( + std::move(shape), + dtype, + std::move(primitive), + std::move(inputs))) {} + std::vector array::make_arrays( const std::vector>& shapes, const std::vector& dtypes, @@ -158,7 +169,22 @@ array::ArrayDesc::ArrayDesc( dtype(dtype), primitive(std::move(primitive)), inputs(inputs) { - std::tie(size, strides) = cum_prod(shape); + std::tie(size, strides) = cum_prod(this->shape); + for (auto& in : inputs) { + is_tracer |= in.is_tracer(); + } +} + +array::ArrayDesc::ArrayDesc( + std::vector&& shape, + Dtype dtype, + std::shared_ptr primitive, + std::vector&& inputs) + : shape(std::move(shape)), + dtype(dtype), + primitive(std::move(primitive)), + inputs(std::move(inputs)) { + std::tie(size, strides) = cum_prod(this->shape); for (auto& in : inputs) { is_tracer |= in.is_tracer(); } diff --git a/mlx/array.h b/mlx/array.h index 0d530162c..8f345bd3f 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -172,6 +172,12 @@ class array { std::shared_ptr primitive, const std::vector& inputs); + array( + std::vector shape, + Dtype dtype, + std::shared_ptr primitive, + std::vector&& inputs); + static std::vector make_arrays( const std::vector>& shapes, const std::vector& dtypes, @@ -215,6 +221,11 @@ class array { return *(array_desc_->primitive); }; + /** A shared pointer to the array's primitive. */ + std::shared_ptr& primitive_ptr() const { + return array_desc_->primitive; + }; + /** Check if the array has an attached primitive or is a leaf node. */ bool has_primitive() const { return array_desc_->primitive != nullptr; @@ -360,6 +371,12 @@ class array { Dtype dtype, std::shared_ptr primitive, const std::vector& inputs); + + explicit ArrayDesc( + std::vector&& shape, + Dtype dtype, + std::shared_ptr primitive, + std::vector&& inputs); }; // The ArrayDesc contains the details of the materialized array including the diff --git a/mlx/compile.cpp b/mlx/compile.cpp new file mode 100644 index 000000000..19a0d368f --- /dev/null +++ b/mlx/compile.cpp @@ -0,0 +1,440 @@ +// Copyright © 2023-2024 Apple Inc. +#include +#include +#include +#include + +#include "mlx/allocator.h" +#include "mlx/primitives.h" +#include "mlx/transforms.h" +#include "mlx/transforms_impl.h" + +namespace mlx::core { + +namespace detail { + +bool& compiler_disabled() { + auto get_val = []() { + if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) { + return true; + } else { + return false; + } + }; + static bool compiler_disabled_ = get_val(); + return compiler_disabled_; +} + +#define MAX_OPS_PER_BUFFER max_ops_per_buffer() + +using CompileFn = std::function(const std::vector&)>; +using ParentsMap = + std::unordered_map>>; + +template +size_t getAddress(std::function f) { + typedef T(fnType)(U...); + fnType** fnPointer = f.template target(); + if (fnPointer == nullptr) { + throw std::invalid_argument( + "[compile] Cannot compile a non-addressable function."); + } + return (size_t)*fnPointer; +} + +struct CompilerCache { + struct CacheEntry { + std::vector inputs; + std::vector outputs; + std::vector tape; + bool empty{true}; + }; + + // Returns a reference to a CacheEntry which can be updated + // by the caller to avoid copying large tapes / inputs / outputs + CacheEntry& find(size_t fun_id, const std::vector& inputs) { + // Try to find the entry + auto [entry_it, inserted] = cache_.insert({fun_id, {}}); + auto& entries = entry_it->second; + auto is_match = [](const std::vector& in1, + const std::vector& in2) { + if (in1.size() != in2.size()) { + throw std::runtime_error( + "[compiler] Got different number of inputs to function," + " this should never happen."); + } + for (int i = 0; i < in1.size(); ++i) { + if (in1[i].shape() != in2[i].shape()) { + return false; + } + if (in1[i].dtype() != in2[i].dtype()) { + return false; + } + } + return true; + }; + + // Loop over entries and check inputs match i.e. shapes and types must be + // equal. Note this could get really slow if one compiles the same + // function with many different shapes. May want to store entries in a + // more easily searchable structure. + for (auto& entry : entries) { + // Check the inputs match and return if so + if (is_match(inputs, entry.inputs)) { + return entry; + } + } + // Otherwise append a new cache entry + entries.push_back(CacheEntry{}); + return entries.back(); + }; + + void erase(size_t fun_id) { + cache_.erase(fun_id); + } + + private: + CompilerCache() { + // Make sure the allocator is fully + // initialized before the compiler cache + allocator::allocator(); + } + friend CompilerCache& compiler_cache(); + std::unordered_map> cache_; +}; + +CompilerCache& compiler_cache() { + static CompilerCache compiler_cache_; + return compiler_cache_; +} + +std::pair, std::vector> compile_trace( + const std::function(const std::vector&)>& fun, + const std::vector& inputs) { + // Set the global tracing flag. + detail::InTracing in_tracing; + + // Run the function on placeholder inputs + // to get compute graph + std::vector tracer_inputs; + for (int i = 0; i < inputs.size(); ++i) { + array in(inputs[i].shape(), inputs[i].dtype(), nullptr, {}); + in.set_tracer(true); + tracer_inputs.push_back(std::move(in)); + } + return {tracer_inputs, fun(tracer_inputs)}; +} + +// Traverses the graph to build a tape and a map of array ids to their parents +std::pair, ParentsMap> compile_dfs( + const std::vector& inputs, + const std::vector& outputs) { + std::function recurse; + std::vector tape; + std::unordered_set input_set; + std::unordered_map>> + parents_map; + for (int i = 0; i < inputs.size(); ++i) { + auto in = inputs[i]; + input_set.insert(in.id()); + } + + // DFS the graph to build the tape, and log parents and scalars + std::unordered_set cache; + recurse = [&](const array& a) { + auto id = a.id(); + if (cache.find(id) != cache.end()) { + return; + } + for (int i = 0; i < a.inputs().size(); i++) { + auto& in = a.inputs()[i]; + parents_map[in.id()].push_back({a, i}); + for (auto& s : a.siblings()) { + parents_map[in.id()].push_back({s, i}); + } + // Don't recurse on inputs (but add them to the tape for the purpose + // of future optimizations) + if (input_set.find(a.id()) == input_set.end()) { + recurse(in); + } + } + cache.insert(id); + for (auto& s : a.siblings()) { + cache.insert(s.id()); + } + tape.push_back(a); + }; + for (auto& a : outputs) { + recurse(a); + } + return {tape, parents_map}; +} + +// Simplify the tape. Note, this function modifies in-place both the tape and +// the parents map to remove orphaned arrays +void compile_simplify( + std::vector& tape, + ParentsMap& parents_map, + const std::vector& outputs, + int passes) { + // Helpers to identify identical scalars + std::map, array> scalars; + auto is_scalar = [](const array& a) { + return a.is_evaled() && a.ndim() == 0; + }; + auto get_scalar_rep = [](const array& a) { + uint64_t v = 0; + int dtype; + switch (a.dtype().size) { + case 1: + v = *a.data(); + break; + case 4: + v = *a.data(); + break; + case 8: + v = *a.data(); + break; + } + return std::make_pair(v, a.dtype().val); + }; + + for (auto& a : tape) { + if (is_scalar(a)) { + scalars.insert({get_scalar_rep(a), a}); + } + } + + // Helper that fuses two arrays in the graph by setting the parents of the + // source to point to the destination + auto fuse = [&](array& dst, array& src) { + // Canonicalize the order of the primitives outputs + auto sources = src.outputs(); + auto dests = dst.outputs(); + // For each src parent, point it to the corresponding dest + for (int i = 0; i < sources.size(); ++i) { + auto src_parents = parents_map.find(sources[i].id()); + if (src_parents == parents_map.end()) { + continue; + } + auto& pairs = parents_map[dests[i].id()]; + for (auto& parent : src_parents->second) { + parent.first.inputs()[parent.second] = dests[i]; + pairs.push_back(parent); + } + // Remove the source from the map to avoid fusing with it again + parents_map.erase(src_parents); + } + }; + + // Depth-1 array equivalence check. + auto array_equivalent = [](const array& a, const array& b) { + if (!a.has_primitive() || !b.has_primitive()) { + return false; + } + if (a.primitive_id() == b.primitive_id()) { + return false; + } + const auto& pa = a.primitive(); + const auto& pb = b.primitive(); + if (typeid(pa) != typeid(pb)) { + return false; + } + + if (a.inputs().size() != b.inputs().size()) { + return false; + } + + for (int i = 0; i < a.inputs().size(); i++) { + if (a.inputs()[i].id() != b.inputs()[i].id()) { + return false; + } + } + + return pa.is_equivalent(pb); + }; + + // Pass 0: fuse scalars + std::vector new_tape; + for (auto& arr : tape) { + // Check if we can fuse scalars + if (is_scalar(arr)) { + auto scalar = scalars.find(get_scalar_rep(arr)); + if (scalar->second.id() != arr.id()) { + fuse(scalar->second, arr); + // Don't keep orphaned scalars in the tape + continue; + } + } + new_tape.push_back(std::move(arr)); + } + + tape = std::move(new_tape); + + std::unordered_set output_set; + for (auto& o : outputs) { + output_set.insert(o.id()); + } + // Pass 1..passes: fuse only keeping non-orphaned arrays in the tape + for (int pass = 0; pass < passes; ++pass) { + for (auto& arr : tape) { + // Helper to check if we can fuse the parents of the + // given array + auto maybe_fuse_parents = [&](auto& a) { + auto parents = parents_map.find(a.id()); + if (parents != parents_map.end()) { + auto N = parents->second.size(); + std::vector mask(N, false); + for (int i = 0; i < N; i++) { + if (mask[i]) { + continue; + } + for (int j = i + 1; j < N; j++) { + if (mask[j]) { + continue; + } + auto& src = parents->second[j].first; + auto& dst = parents->second[i].first; + if (src.id() != dst.id() && array_equivalent(src, dst)) { + fuse(dst, src); + mask[j] = true; + } + } + } + // Erase orphaned parents so we don't keep fusing with them + for (int i = N - 1; i > 0; --i) { + if (mask[i]) { + parents->second.erase(parents->second.begin() + i); + } + } + return false; + } else { + return output_set.find(a.id()) == output_set.end(); + } + }; + + bool discard = maybe_fuse_parents(arr); + for (auto& s : arr.siblings()) { + discard &= maybe_fuse_parents(s); + } + // If an array and its siblings have no parents, and none of them are + // outputs, it is safe to remove it from the tape + if (!discard) { + new_tape.push_back(std::move(arr)); + } + } + tape = std::move(new_tape); + } +} + +std::vector compile_replace( + const std::vector& tape, + const std::vector& trace_inputs, + const std::vector& trace_outputs, + const std::vector& inputs) { + std::unordered_map trace_to_real; + for (int i = 0; i < inputs.size(); ++i) { + trace_to_real.insert({trace_inputs[i].id(), inputs[i]}); + } + + for (auto& a : tape) { + // Arrays in the tape without primitives are constants + // and can be used directly + if (!a.has_primitive()) { + trace_to_real.insert({a.id(), a}); + } else { + // Find real inputs + std::vector real_inputs; + for (auto& in : a.inputs()) { + real_inputs.push_back(trace_to_real.at(in.id())); + } + if (a.siblings().empty()) { + auto real_a = array( + a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)); + trace_to_real.insert({a.id(), std::move(real_a)}); + } else { + // Ensure the order is correct for multi-output primitives + std::vector> shapes; + std::vector types; + auto trace_out = a.outputs(); + for (auto& o : trace_out) { + shapes.push_back(o.shape()); + types.push_back(o.dtype()); + } + auto real_out = + array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs); + for (int i = 0; i < trace_out.size(); ++i) { + trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])}); + } + } + } + } + + std::vector outputs; + for (auto& o : trace_outputs) { + outputs.push_back(trace_to_real.at(o.id())); + } + return outputs; +} + +std::function(const std::vector&)> compile( + const std::function(const std::vector&)>& fun, + size_t fun_id) { + if (compiler_disabled()) { + return fun; + } + return [fun, fun_id](const std::vector& inputs) { + // Find a cache entry with the correct inputs + auto& entry = compiler_cache().find(fun_id, inputs); + + // No matching cache entry existed, so compile + if (entry.empty) { + // Mark the entry as not empty since we are about to fill it + entry.empty = false; + // Trace to build the graph + std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs); + + // DFS the graph and get a tape, and a map of array id to (parent, + // position in parent inputs) + std::unordered_map>> + parents_map; + std::tie(entry.tape, parents_map) = + compile_dfs(entry.inputs, entry.outputs); + + // Simplify the tape + compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 3); + + // This is a good point to do more optimizations, e.g. kernel fusion to + // generate new primitives. The tape needs to be updated accordingly + } + + // At this point we must have a tape, now replace the placeholders + // with real arrays that can be evaluated + return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs); + }; +} + +void compile_erase(size_t fun_id) { + detail::compiler_cache().erase(fun_id); +} + +} // namespace detail + +std::function(const std::vector&)> compile( + const std::function(const std::vector&)>& fun) { + if (detail::compiler_disabled()) { + return fun; + } + auto fun_id = detail::getAddress(fun); + return detail::compile(fun, fun_id); +} + +void disable_compile() { + detail::compiler_disabled() = true; +} + +void enable_compile() { + detail::compiler_disabled() = false; +} + +} // namespace mlx::core diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 03d8d993b..86cf69989 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -1,7 +1,6 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include -#include #include #include #include @@ -35,169 +34,6 @@ class Synchronizer : public Primitive { // are currently under a function transformation. int detail::InTracing::tracing_counter{0}; -void simplify(const std::vector& outputs) { - // Some notes about how this function works - // - // Step 1: Traverse the graph and build a tape. During the graph - // traversal we: - // - Build a map of inputs to their parents. - // - Record scalar inputs in a map in order to fuse them. - // Step 2: Process the tape. A node in the tape has inputs and outputs. - // - Scalar inputs are replaced with their canonical scalar - // - We check each inputs output nodes. Every output node that matches - // the current node gets fused into the current node. - std::function recurse; - std::queue tape; - std::unordered_set cache; - std::unordered_map>> - parents_map; - - // Helpers to identify identical scalars - std::map, array> scalars; - auto is_scalar = [](const array& a) { - return a.is_evaled() && a.ndim() == 0; - }; - auto get_scalar_rep = [](const array& a) { - uint64_t v = 0; - int dtype; - switch (a.dtype().size) { - case 1: - v = *a.data(); - break; - case 4: - v = *a.data(); - break; - case 8: - v = *a.data(); - break; - } - return std::make_pair(v, a.dtype().val); - }; - - // DFS the graph to build the tape, and log parents and scalars - recurse = [&](const array& a) { - auto id = a.id(); - if (cache.find(id) != cache.end()) { - return; - } - for (int i = 0; i < a.inputs().size(); i++) { - auto& in = a.inputs()[i]; - parents_map[in.id()].push_back({a, i}); - for (auto& s : a.siblings()) { - parents_map[in.id()].push_back({s, i}); - } - recurse(in); - } - cache.insert(id); - for (auto& s : a.siblings()) { - cache.insert(s.id()); - } - - tape.push(a); - if (is_scalar(a)) { - scalars.insert({get_scalar_rep(a), a}); - } - }; - for (auto& a : outputs) { - recurse(a); - } - - // Helper that fuses two arrays in the graph by setting the parents of the - // source to point to the destination - auto fuse = [&](array& dst, array& src) { - // Canonicalize the order of the primitives outputs - auto sources = src.outputs(); - auto dests = dst.outputs(); - // For each src parent, point it to the corresponding dest - for (int i = 0; i < sources.size(); ++i) { - auto src_parents = parents_map.find(sources[i].id()); - if (src_parents == parents_map.end()) { - continue; - } - auto& pairs = parents_map[dests[i].id()]; - for (auto& parent : src_parents->second) { - parent.first.inputs()[parent.second] = dests[i]; - pairs.push_back(parent); - } - // Remove the source from the map to avoid fusing with it again - parents_map.erase(src_parents); - } - }; - - // Depth-1 array equivalence check. - auto array_equivalent = [](const array& a, const array& b) { - if (!a.has_primitive() || !b.has_primitive()) { - return false; - } - if (a.primitive_id() == b.primitive_id()) { - return false; - } - const auto& pa = a.primitive(); - const auto& pb = b.primitive(); - if (typeid(pa) != typeid(pb)) { - return false; - } - - if (a.inputs().size() != b.inputs().size()) { - return false; - } - - for (int i = 0; i < a.inputs().size(); i++) { - if (a.inputs()[i].id() != b.inputs()[i].id()) { - return false; - } - } - - return pa.is_equivalent(pb); - }; - - // Walk the graph - while (!tape.empty()) { - auto arr = std::move(tape.front()); - tape.pop(); - - // Check if we can fuse scalars - if (is_scalar(arr)) { - auto scalar = scalars.find(get_scalar_rep(arr)); - if (scalar->second.id() != arr.id()) { - fuse(scalar->second, arr); - arr = scalar->second; - } - } - - // Helper to check if we can fuse the parents of the - // given array - auto maybe_fuse_parents = [&](auto& a) { - auto parents = parents_map.find(a.id()); - if (parents != parents_map.end()) { - auto N = parents->second.size(); - std::vector mask(N, false); - for (int i = 0; i < N; i++) { - if (mask[i]) { - continue; - } - for (int j = i + 1; j < N; j++) { - if (mask[j]) { - continue; - } - auto& src = parents->second[j].first; - auto& dst = parents->second[i].first; - if (src.id() != dst.id() && array_equivalent(src, dst)) { - fuse(dst, src); - mask[j] = true; - } - } - } - } - }; - - maybe_fuse_parents(arr); - for (auto& s : arr.siblings()) { - maybe_fuse_parents(s); - } - } -} - void eval(const std::vector& outputs) { std::function recurse; std::queue tape; diff --git a/mlx/transforms.h b/mlx/transforms.h index 813c5f7fd..7a13b8f7e 100644 --- a/mlx/transforms.h +++ b/mlx/transforms.h @@ -1,18 +1,25 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #pragma once -#include "array.h" +#include "mlx/array.h" namespace mlx::core { -/** Fuse equivalent arrays to avoid duplicate execution. */ -void simplify(const std::vector& outputs); +// Compile takes a function and returns a new function +std::function(const std::vector&)> compile( + const std::function(const std::vector&)>& fun); -template -void simplify(Arrays... outputs) { - simplify(std::vector{std::forward(outputs)...}); -} +/** Globally disable compilation. + * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also + * be used to disable compilation. + */ +void disable_compile(); + +/** Globally enable compilation. + * This will override the environment variable ``MLX_DISABLE_COMPILE``. + */ +void enable_compile(); void eval(const std::vector& outputs); diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index 4b464bafd..a1b3461ab 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. namespace mlx::core::detail { @@ -14,6 +14,15 @@ std::vector vmap_replace( const std::vector& in_axes, const std::vector& out_axes); +// This is not part of the general C++ API as calling with a bad id is a bad +// idea. +std::function(const std::vector&)> compile( + const std::function(const std::vector&)>& fun, + size_t fun_id); + +// Erase cached compile functions +void compile_erase(size_t fun_id); + // Create an InTracing object during tracing operations to signify to the rest // of the codebase that we are during tracing so evals should not throw away // the graph. diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 0d217a9f6..8706706d6 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -1,5 +1,4 @@ -// Copyright © 2023 Apple Inc. - +// Copyright © 2023-2024 Apple Inc. #include #include #include @@ -163,6 +162,19 @@ py::object tree_unflatten( }); } +py::object tree_unflatten_none( + py::object tree, + const std::vector& values, + int index = 0) { + return tree_map(tree, [&](py::handle obj) { + if (py::isinstance(obj)) { + return py::cast(values[index++]); + } else { + return py::cast(obj); + } + }); +} + auto validate_argnums_argnames( const std::optional& argnums, const StrOrVec& argnames) { @@ -437,6 +449,58 @@ auto py_vmap( }; } +std::unordered_map& tree_cache() { + // This map is used to Cache the tree structure of the outputs + static std::unordered_map tree_cache_; + return tree_cache_; +} + +struct PyCompiledFun { + py::function fun; + size_t fun_id; + + PyCompiledFun(const py::function& fun) + : fun(fun), fun_id(reinterpret_cast(fun.ptr())) {} + + PyCompiledFun(const PyCompiledFun&) = delete; + PyCompiledFun& operator=(const PyCompiledFun&) = delete; + PyCompiledFun& operator=(PyCompiledFun&& other) = delete; + PyCompiledFun(PyCompiledFun&& other) + : fun(std::move(other.fun)), fun_id(reinterpret_cast(fun.ptr())) { + other.fun_id = 0; + }; + + py::object operator()(const py::args& args) { + auto compile_fun = [this, &args](const std::vector& a) { + // Call the python function + py::object py_outputs = this->fun(*tree_unflatten(args, a)); + + // Flatten the outputs + auto outputs = tree_flatten(py_outputs, true); + + py_outputs = + tree_map(py_outputs, [](const py::handle& x) { return py::none(); }); + tree_cache().insert({this->fun_id, py_outputs}); + return outputs; + }; + + // Inputs must be array or tree of arrays + auto inputs = tree_flatten(args, true); + + // Compile and call + auto outputs = detail::compile(compile_fun, fun_id)(inputs); + + // Put the outputs back in the container + py::object py_outputs = tree_cache().at(fun_id); + return tree_unflatten_none(py_outputs, outputs); + }; + + ~PyCompiledFun() { + tree_cache().erase(fun_id); + detail::compile_erase(fun_id); + } +}; + void init_transforms(py::module_& m) { py::options options; options.disable_function_signatures(); @@ -679,45 +743,6 @@ void init_transforms(py::module_& m) { Returns: function: The vectorized function. )pbdoc"); - m.def( - "simplify", - [](const py::args& args) { - std::vector arrays = tree_flatten(args); - simplify(arrays); - }, - R"pbdoc( - simplify(*args) -> None - - Simplify the graph that computes the arrays. - - Run a few fast graph simplification operations to reuse computation and - reduce memory consumption. This function is meant to be run every time - so its overhead should be small, approximately 1ms for a graph with a - few thousand nodes. - - .. code-block:: python - - import mlx.core as mx - - def foo(x): - y = x @ x - z = x @ x - return y + z - - x = mx.ones((10, 10)) - y = foo(x) - z = foo(x) - - # Computes the matmul twice - mx.eval(y) - - # Computes the matmul once - mx.simplify(z) - mx.eval(z) - - Args: - args: Any number of arrays and/or trees of arrays to be simplified. - )pbdoc"); m.def( "export_to_dot", [](py::object file, const py::args& args) { @@ -736,4 +761,46 @@ void init_transforms(py::module_& m) { } }, "file"_a); + m.def( + "compile", + [](const py::function& fun) { + return py::cpp_function(PyCompiledFun{fun}); + }, + "fun"_a, + R"pbdoc( + compile(fun: function) -> function + + Returns a compiled function which produces the same output as ``fun``. + + Args: + fun (function): A function which takes a variable number of + :class:`array` or trees of :class:`array` and returns + a variable number of :class:`array` or trees of :class:`array`. + + Returns: + function: A compiled function which has the same input arguments + as ``fun`` and returns the the same output(s). + )pbdoc"); + m.def( + "disable_compile", + &disable_compile, + R"pbdoc( + disable_compile() -> None + + Globally disable compilation. Setting the environment variable + ``MLX_DISABLE_COMPILE`` can also be used to disable compilation. + )pbdoc"); + m.def( + "enable_compile", + &enable_compile, + R"pbdoc( + enable_compiler() -> None + + Globally enable compilation. This will override the environment + variable ``MLX_DISABLE_COMPILE`` if set. + )pbdoc"); + + // Register static Python object cleanup before the interpreter exits + auto atexit = py::module_::import("atexit"); + atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); })); } diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py new file mode 100644 index 000000000..594dd2859 --- /dev/null +++ b/python/tests/test_compile.py @@ -0,0 +1,195 @@ +# Copyright © 2023-2024 Apple Inc. + +import io +import unittest + +import mlx.core as mx +import mlx_tests + + +class TestCompile(mlx_tests.MLXTestCase): + def test_simple_compile(self): + def fun(x, y): + return x + y + + compiled_fn = mx.compile(fun) + compiled_fn = mx.compile(fun) + x = mx.array(1.0) + y = mx.array(1.0) + out = compiled_fn(x, y) + self.assertEqual(out.item(), 2.0) + + # Try again + out = compiled_fn(x, y) + self.assertEqual(out.item(), 2.0) + + # Change sizes + x = mx.array([1.0, 2.0]) + out = compiled_fn(x, y) + self.assertTrue(mx.array_equal(out, mx.array([2.0, 3.0]))) + + y = mx.array([1.0, 2.0]) + out = compiled_fn(x, y) + self.assertTrue(mx.array_equal(out, mx.array([2.0, 4.0]))) + + # Change types + x = mx.array([1, 2], mx.int32) + y = mx.array([1, 2], mx.int32) + out = compiled_fn(x, y) + self.assertEqual(out.dtype, mx.int32) + self.assertTrue(mx.array_equal(out, mx.array([2, 4]))) + + def test_compile_grad(self): + def loss_fn(x): + return mx.exp(x).sum() + + grad_fn = mx.grad(loss_fn) + + x = mx.array([0.5, -0.5, 1.2]) + dfdx = grad_fn(x) + compile_grad_fn = mx.compile(grad_fn) + c_dfdx = grad_fn(x) + + self.assertTrue(mx.allclose(c_dfdx, dfdx)) + + # Run it again without calling compile + c_dfdx = compile_grad_fn(x) + self.assertTrue(mx.allclose(c_dfdx, dfdx)) + + # Run it again with calling compile + c_dfdx = mx.compile(grad_fn)(x) + self.assertTrue(mx.allclose(c_dfdx, dfdx)) + + # Value and grad + def loss_fn(x): + return mx.exp(x).sum(), mx.sin(x) + + val_and_grad_fn = mx.value_and_grad(loss_fn) + (loss, val), dfdx = val_and_grad_fn(x) + (c_loss, c_val), c_dfdx = mx.compile(val_and_grad_fn)(x) + + self.assertTrue(mx.allclose(c_dfdx, dfdx)) + self.assertTrue(mx.allclose(c_loss, loss)) + self.assertTrue(mx.allclose(c_val, val)) + + def test_compile_inputs_with_primitives(self): + x = mx.array([1, 2, 3]) + y = mx.array([1, 2, 3]) + for _ in range(5): + x = x + y + y = y + 1 + + def fun(x, y): + return x * y + + out = fun(x, y) + + x = mx.array([1, 2, 3]) + y = mx.array([1, 2, 3]) + for _ in range(5): + x = x + y + y = y + 1 + + c_out = mx.compile(fun)(x, y) + self.assertTrue(mx.array_equal(out, c_out)) + + # Try again + c_out = mx.compile(fun)(x, y) + self.assertTrue(mx.array_equal(out, c_out)) + + def test_compile_with_closure(self): + x = mx.array(1) + + def closure(y): + return x + y + + compiled = mx.compile(closure) + out = compiled(mx.array(1)) + self.assertEqual(out.item(), 2) + + # Try again + out = compiled(mx.array(1)) + self.assertEqual(out.item(), 2) + + # Change the shape of the enclosed variable + x = mx.array([1, 2]) + out = compiled(mx.array(1)) + + # We still get the original input (closures are not updated) + self.assertEqual(out.item(), 2) + + # Try with a tree of enclosed variables + x = {"a": mx.array(1), "b": mx.array(2)} + + def closure(y): + return x["a"] + y + x["b"] + + compiled = mx.compile(closure) + out = compiled(mx.array(1)) + self.assertEqual(out.item(), 4) + + # Change the shape of one input + x["a"] = mx.array([4, 5]) + out = compiled(mx.array(1)) + self.assertEqual(out.item(), 4) + + x["b"] = mx.array([-6, -8]) + out = compiled(mx.array(1)) + self.assertEqual(out.item(), 4) + + # Enclosed variable is not evaluated yet + x = mx.array(1) + x = x + x + + def closure(y): + return x + y + + compiled = mx.compile(closure) + out = compiled(mx.array(2)) + self.assertEqual(out.item(), 4) + + # And again + out = compiled(mx.array(2)) + self.assertEqual(out.item(), 4) + + def test_function_creates_array(self): + def fun(x): + return x + mx.array(1) + + cfun = mx.compile(fun) + out = cfun(mx.array(3)) + self.assertEqual(out.item(), 4) + + # And again + out = cfun(mx.array(3)) + self.assertEqual(out.item(), 4) + + def test_enable_disable(self): + def fun(x): + y = x + 1 + z = x + 1 + return y + z + + def count_prims(outputs): + buf = io.StringIO() + mx.export_to_dot(buf, outputs) + buf.seek(0) + return len([l for l in buf.read().split() if "label" in l]) + + x = mx.array(1.0) + cfun = mx.compile(fun) + n_compiled = count_prims(cfun(x)) + + # Check disabled + mx.disable_compile() + n_uncompiled = count_prims(cfun(x)) + self.assertTrue(n_compiled < n_uncompiled) + + # Check renabled + mx.enable_compile() + n_enable_compiled = count_prims(cfun(x)) + self.assertEqual(n_compiled, n_enable_compiled) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 838650c2a..f881d5792 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -20,11 +20,11 @@ target_sources(tests PRIVATE arg_reduce_tests.cpp autograd_tests.cpp blas_tests.cpp + compile_tests.cpp creations_tests.cpp device_tests.cpp eval_tests.cpp fft_tests.cpp - graph_optimize_tests.cpp load_tests.cpp ops_tests.cpp random_tests.cpp diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp new file mode 100644 index 000000000..3d66d6afd --- /dev/null +++ b/tests/compile_tests.cpp @@ -0,0 +1,213 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "doctest/doctest.h" + +#include "mlx/mlx.h" + +using namespace mlx::core; + +std::vector simple_fun(const std::vector& inputs) { + return std::vector{inputs[0] + inputs[1]}; +} + +TEST_CASE("test simple compile") { + auto compfn = compile(simple_fun); + auto out = compfn({array(1.0f), array(2.0f)})[0]; + CHECK_EQ(out.item(), 3.0f); + + out = compfn({array(1.0f), array(2.0f)})[0]; + CHECK_EQ(out.item(), 3.0f); + + // Change the shapes + out = compfn({array({1.0f, 2.0f}), array(2.0f)})[0]; + CHECK(array_equal(out, array({3.0f, 4.0f})).item()); + + out = compfn({array(2.0f), array({1.0f, 2.0f})})[0]; + CHECK(array_equal(out, array({3.0f, 4.0f})).item()); + + // Change the types + out = compfn({array(2, int32), array({1.0f, 2.0f})})[0]; + CHECK(array_equal(out, array({3.0f, 4.0f})).item()); + + out = compfn({array(2.0f), array({1, 2}, int32)})[0]; + CHECK(array_equal(out, array({3.0f, 4.0f})).item()); +} + +std::vector grad_fun(const std::vector& inputs) { + auto loss = [](std::vector ins) { return exp(ins[0] + ins[1]); }; + return grad(loss, {0, 1})(inputs); +} + +TEST_CASE("test compile with grad") { + auto x = array(1.0f); + auto y = array(1.0f); + auto grads_expected = grad_fun({x, y}); + auto grads_compile = compile(grad_fun)({x, y}); + CHECK_EQ(grads_compile[0].item(), grads_expected[0].item()); + CHECK_EQ(grads_compile[1].item(), grads_expected[1].item()); +} + +TEST_CASE("test compile inputs with primitive") { + auto [k1, k2] = random::split(random::key(0)); + auto x = random::uniform({5, 5}, k1); + auto y = random::uniform({5, 5}, k2); + auto expected = simple_fun({x, y})[0]; + + x = random::uniform({5, 5}, k1); + y = random::uniform({5, 5}, k2); + auto out = compile(simple_fun)({x, y})[0]; + CHECK(array_equal(expected, out).item()); + + // Same thing twice + out = compile(simple_fun)({x, y})[0]; + CHECK(array_equal(expected, out).item()); +} + +std::vector fun_creats_array(const std::vector& inputs) { + return {inputs[0] + array(1.0f)}; +} + +TEST_CASE("test compile with created array") { + auto cfun = compile(fun_creats_array); + auto out = cfun({array(2.0f)}); + CHECK_EQ(out[0].item(), 3.0f); + + // Try again + out = cfun({array(2.0f)}); + CHECK_EQ(out[0].item(), 3.0f); +} + +std::vector inner_fun(const std::vector& inputs) { + return {array(2) * inputs[0]}; +} + +std::vector outer_fun(const std::vector& inputs) { + auto x = inputs[0] + inputs[1]; + auto y = compile(inner_fun)({x})[0]; + return {x + y}; +} + +TEST_CASE("test nested compile") { + auto cfun = compile(outer_fun); + auto out = cfun({array(1), array(2)})[0]; + CHECK_EQ(out.item(), 9); + + // Try again + out = cfun({array(1), array(2)})[0]; + CHECK_EQ(out.item(), 9); +} + +TEST_CASE("test enable and disable compile") { + CHECK_THROWS(compile(nullptr)); + disable_compile(); + compile(nullptr); + enable_compile(); + CHECK_THROWS(compile(nullptr)); +} + +auto add_scalars(const std::vector&) { + auto a = array(-1.0f); + auto b = array(-1.0f); + return std::vector{abs(a), abs(b)}; +}; + +auto max_scalars(const std::vector&) { + auto a = array({-1.0f, 2.0f}); + auto b = maximum(a, array(0.0f)); + auto c = maximum(-a, array(0.0f)); + auto d = b + c; + return std::vector{b, c, d}; +}; + +TEST_CASE("test simplify scalars") { + { + auto cfun = compile(add_scalars); + auto out = cfun({}); + auto c = out[0]; + auto d = out[1]; + CHECK(c.inputs()[0].id() == d.inputs()[0].id()); + } + + { + auto a = array({-1.0f, 2.0f}); + auto out = compile(max_scalars)({a}); + auto b = out[0]; + auto c = out[1]; + auto d = out[2]; + CHECK(b.inputs()[1].id() == c.inputs()[1].id()); + } +} + +auto exp_two(const std::vector& inputs) { + auto a = inputs[0]; + return std::vector{exp(a) + exp(a)}; +}; + +TEST_CASE("test simplify") { + auto a = array({1.0f, 2.0f}); + auto b = compile(exp_two)({a})[0]; + CHECK(b.inputs()[0].id() == b.inputs()[1].id()); +} + +auto add_diff(const std::vector& inputs) { + auto a = inputs[0]; + return std::vector{cos(a) + sin(a)}; +}; + +TEST_CASE("test no simplify") { + auto a = array({1.0f, 2.0f}); + auto b = compile(add_diff)({a})[0]; + CHECK(b.inputs()[0].id() != b.inputs()[1].id()); +} + +auto multi_one(const std::vector&) { + auto a = array(1.0); + auto b = array(2.0); + auto c = divmod(a, b); + auto d = divmod(a, b); + auto e = c[0] + d[0]; + auto f = c[1] + d[1]; + return std::vector{e, f}; +} + +auto multi_two(const std::vector&) { + auto a = array(1.0); + auto b = array(1.0); + auto c = divmod(a, b); + return std::vector{c}; +} + +auto multi_three(const std::vector&) { + auto a = array(1.0); + auto b = array(2.0); + auto c = divmod(a, b); + auto d = divmod(a, b); + auto e = stack({c[0], c[1], d[0], d[1]}); + return std::vector{e}; +} + +TEST_CASE("test simplify multi output") { + { + auto out = compile(multi_one)({}); + auto e = out[0]; + auto f = out[1]; + CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id()); + CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id()); + } + + { + auto c = compile(multi_two)({}); + CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id()); + CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id()); + CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id()); + } + + // Make sure the output order of multi-output primitives + // is respected in simplification + { + auto e = compile(multi_three)({})[0]; + CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item()); + CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id()); + CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id()); + } +} diff --git a/tests/graph_optimize_tests.cpp b/tests/graph_optimize_tests.cpp deleted file mode 100644 index d30005fc6..000000000 --- a/tests/graph_optimize_tests.cpp +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#include "doctest/doctest.h" - -#include "mlx/mlx.h" - -using namespace mlx::core; - -TEST_CASE("test simplify scalars") { - { - auto a = array(-1.0f); - auto b = array(-1.0f); - auto c = abs(a); - auto d = abs(b); - simplify({c, d}); - CHECK(c.inputs()[0].id() == d.inputs()[0].id()); - } - - { - auto a = array({-1.0f, 2.0f}); - auto b = maximum(a, array(0.0f)); - auto c = maximum(-a, array(0.0f)); - auto d = b + c; - simplify({d}); - CHECK(b.inputs()[1].id() == c.inputs()[1].id()); - } -} - -TEST_CASE("test simplify") { - auto a = array({1.0f, 2.0f}); - auto b = exp(a) + exp(a); - simplify(b); - CHECK(b.inputs()[0].id() == b.inputs()[1].id()); -} - -TEST_CASE("test no simplify") { - auto a = array({1.0f, 2.0f}); - auto b = cos(a) + sin(a); - simplify(b); - CHECK(b.inputs()[0].id() != b.inputs()[1].id()); -} - -TEST_CASE("test simplify multi output") { - { - auto a = array(1.0); - auto b = array(2.0); - auto c = divmod(a, b); - auto d = divmod(a, b); - auto e = c[0] + d[0]; - auto f = c[1] + d[1]; - - simplify({e, f}); - CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id()); - CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id()); - } - - { - auto a = array(1.0); - auto b = array(1.0); - auto c = divmod(a, b); - simplify(c); - CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id()); - CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id()); - CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id()); - } - - // Make sure the output order of multi-output primitives - // is respected in simplification - { - auto a = array(1.0); - auto b = array(2.0); - auto c = divmod(a, b); - auto d = divmod(a, b); - auto e = stack({c[0], c[1], d[0], d[1]}); - simplify(e); - CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item()); - CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id()); - CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id()); - } -} diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 927ad2874..9e5f9e277 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -511,6 +511,7 @@ TEST_CASE("test is inf") { CHECK_FALSE(isinf(x).item()); auto inf = std::numeric_limits::infinity(); + array y(inf); CHECK(isinf(y).item());