From da8c88578442680ca541fc6aa9d84a067319b890 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 9 Jan 2025 11:23:19 -0800 Subject: [PATCH] Simplify removes no-ops from the tape (#1759) * simplify removes no-ops from the tape * comment --- mlx/backend/common/compiled.h | 4 +--- mlx/compile.cpp | 32 +++++++++++++++++++++++++++----- mlx/compile_impl.h | 5 ++--- tests/compile_tests.cpp | 11 +++++++++++ 4 files changed, 41 insertions(+), 11 deletions(-) diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h index 72959fdc9..c79d852a0 100644 --- a/mlx/backend/common/compiled.h +++ b/mlx/backend/common/compiled.h @@ -11,9 +11,7 @@ namespace mlx::core { inline bool is_static_cast(const Primitive& p) { - return ( - typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) || - typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType)); + return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType)); } std::string build_lib_name( diff --git a/mlx/compile.cpp b/mlx/compile.cpp index f0ca6a679..45fc1c1aa 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -68,8 +68,7 @@ bool is_reduction(const Primitive& p) { } bool is_fusable(const Primitive& p) { - return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p) || - is_noop(p); + return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p); } Compiled::Compiled( @@ -351,12 +350,12 @@ std::pair, ParentsMap> compile_dfs( return {tape, parents_map}; } -// Simplify the tape. Note, this function modifies in-place both the tape and -// the parents map to remove orphaned arrays +// Simplify the tape. 
Note, this function modifies in-place the tape, +// the parents map to remove orphaned arrays, and potentially the outputs void compile_simplify( std::vector& tape, ParentsMap& parents_map, - const std::vector& outputs, + std::vector& outputs, int passes) { // Helpers to identify identical scalars std::map, array> scalars; @@ -433,6 +432,28 @@ void compile_simplify( } tape = std::move(new_tape); + // Remove no-ops + { + std::unordered_map output_map; + for (auto& o : outputs) { + output_map.insert({o.id(), o}); + } + for (auto& arr : tape) { + if (!arr.has_primitive() || !is_noop(arr.primitive())) { + new_tape.push_back(std::move(arr)); + continue; + } + merge_one(arr.inputs()[0], arr, parents_map); + if (auto it = output_map.find(arr.id()); it != output_map.end()) { + it->second = arr.inputs()[0]; + } + } + tape = std::move(new_tape); + for (auto& o : outputs) { + o = output_map.at(o.id()); + } + } + std::unordered_map tape_order; for (uint32_t i = 0; i < tape.size(); ++i) { tape_order.insert({tape[i].id(), i}); @@ -442,6 +463,7 @@ void compile_simplify( for (auto& o : outputs) { output_set.insert(o.id()); } + // Multi-pass merge only keeping non-orphaned arrays in the tape for (int pass = 0; pass < passes; ++pass) { for (auto& arr : tape) { diff --git a/mlx/compile_impl.h b/mlx/compile_impl.h index 7e9235f36..bc3f44b4e 100644 --- a/mlx/compile_impl.h +++ b/mlx/compile_impl.h @@ -39,12 +39,11 @@ std::pair, ParentsMap> compile_dfs( const std::vector& outputs, const std::vector& original_inputs); -// Simplify the tape. Note, this function modifies in-place both the tape and -// the parents map to remove orphaned arrays +// Simplify the tape. 
void compile_simplify( std::vector& tape, ParentsMap& parents_map, - const std::vector& outputs, + std::vector& outputs, int passes); std::vector compile_replace( diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 5026c8505..66511682d 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -157,6 +157,17 @@ TEST_CASE("test simplify") { set_compile_mode(CompileMode::enabled); } +TEST_CASE("test simplify noops") { + set_compile_mode(CompileMode::no_fuse); + auto a = array({1.0f, 2.0f}); + auto fun = [](const std::vector& inputs) -> std::vector { + return {copy(stop_gradient(exp(stop_gradient(inputs[0]))))}; + }; + auto b = compile(fun)({a})[0]; + CHECK(b.inputs()[0].id() == a.id()); + set_compile_mode(CompileMode::enabled); +} + auto add_diff(const std::vector& inputs) { auto a = inputs[0]; return std::vector{cos(a) + sin(a)};