From 1c3f82ca17beea5f05e43bcde0c3655c3836a06f Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 16 Jan 2024 20:14:27 -0800 Subject: [PATCH] remove simplify --- mlx/transforms.cpp | 163 --------------------------------- mlx/transforms.h | 8 -- python/src/transforms.cpp | 39 -------- tests/CMakeLists.txt | 1 - tests/compile_tests.cpp | 74 +++++++++++++++ tests/graph_optimize_tests.cpp | 80 ---------------- 6 files changed, 74 insertions(+), 291 deletions(-) delete mode 100644 tests/graph_optimize_tests.cpp diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 08dd92e0a..d0bd0de74 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -35,169 +35,6 @@ class Synchronizer : public Primitive { // are currently under a function transformation. int detail::InTracing::tracing_counter{0}; -void simplify(const std::vector& outputs) { - // Some notes about how this function works - // - // Step 1: Traverse the graph and build a tape. During the graph - // traversal we: - // - Build a map of inputs to their parents. - // - Record scalar inputs in a map in order to fuse them. - // Step 2: Process the tape. A node in the tape has inputs and outputs. - // - Scalar inputs are replaced with their canonical scalar - // - We check each inputs output nodes. Every output node that matches - // the current node gets fused into the current node. - std::function recurse; - std::queue tape; - std::unordered_set cache; - std::unordered_map>> - parents_map; - - // Helpers to identify identical scalars - std::map, array> scalars; - auto is_scalar = [](const array& a) { - return a.is_evaled() && a.ndim() == 0; - }; - auto get_scalar_rep = [](const array& a) { - uint64_t v = 0; - int dtype; - switch (a.dtype().size) { - case 1: - v = *a.data(); - break; - case 4: - v = *a.data(); - break; - case 8: - v = *a.data(); - break; - } - return std::make_pair(v, a.dtype().val); - }; - - // DFS the graph to build the tape, and log parents and scalars - recurse = [&](const array& a) { - auto id = a.id(); - if (cache.find(id) != cache.end()) { - return; - } - for (int i = 0; i < a.inputs().size(); i++) { - auto& in = a.inputs()[i]; - parents_map[in.id()].push_back({a, i}); - for (auto& s : a.siblings()) { - parents_map[in.id()].push_back({s, i}); - } - recurse(in); - } - cache.insert(id); - for (auto& s : a.siblings()) { - cache.insert(s.id()); - } - - tape.push(a); - if (is_scalar(a)) { - scalars.insert({get_scalar_rep(a), a}); - } - }; - for (auto& a : outputs) { - recurse(a); - } - - // Helper that fuses two arrays in the graph by setting the parents of the - // source to point to the destination - auto fuse = [&](array& dst, array& src) { - // Canonicalize the order of the primitives outputs - auto sources = src.outputs(); - auto dests = dst.outputs(); - // For each src parent, point it to the corresponding dest - for (int i = 0; i < sources.size(); ++i) { - auto src_parents = parents_map.find(sources[i].id()); - if (src_parents == parents_map.end()) { - continue; - } - auto& pairs = parents_map[dests[i].id()]; - for (auto& parent : src_parents->second) { - parent.first.inputs()[parent.second] = dests[i]; - pairs.push_back(parent); - } - // Remove the source from the map to avoid fusing with it again - parents_map.erase(src_parents); - } - }; - - // Depth-1 array equivalence check. - auto array_equivalent = [](const array& a, const array& b) { - if (!a.has_primitive() || !b.has_primitive()) { - return false; - } - if (a.primitive_id() == b.primitive_id()) { - return false; - } - const auto& pa = a.primitive(); - const auto& pb = b.primitive(); - if (typeid(pa) != typeid(pb)) { - return false; - } - - if (a.inputs().size() != b.inputs().size()) { - return false; - } - - for (int i = 0; i < a.inputs().size(); i++) { - if (a.inputs()[i].id() != b.inputs()[i].id()) { - return false; - } - } - - return pa.is_equivalent(pb); - }; - - // Walk the graph - while (!tape.empty()) { - auto arr = std::move(tape.front()); - tape.pop(); - - // Check if we can fuse scalars - if (is_scalar(arr)) { - auto scalar = scalars.find(get_scalar_rep(arr)); - if (scalar->second.id() != arr.id()) { - fuse(scalar->second, arr); - arr = scalar->second; - } - } - - // Helper to check if we can fuse the parents of the - // given array - auto maybe_fuse_parents = [&](auto& a) { - auto parents = parents_map.find(a.id()); - if (parents != parents_map.end()) { - auto N = parents->second.size(); - std::vector mask(N, false); - for (int i = 0; i < N; i++) { - if (mask[i]) { - continue; - } - for (int j = i + 1; j < N; j++) { - if (mask[j]) { - continue; - } - auto& src = parents->second[j].first; - auto& dst = parents->second[i].first; - if (src.id() != dst.id() && array_equivalent(src, dst)) { - fuse(dst, src); - mask[j] = true; - } - } - } - } - }; - - maybe_fuse_parents(arr); - for (auto& s : arr.siblings()) { - maybe_fuse_parents(s); - } - } -} - void eval(const std::vector& outputs) { std::function recurse; std::queue tape; diff --git a/mlx/transforms.h b/mlx/transforms.h index 340c40bc5..b380229ec 100644 --- a/mlx/transforms.h +++ b/mlx/transforms.h @@ -21,14 +21,6 @@ void disable_compiler(); */ void enable_compiler(); -/** Fuse equivalent arrays to avoid duplicate execution. */ -void simplify(const std::vector& outputs); - -template -void simplify(Arrays... outputs) { - simplify(std::vector{std::forward(outputs)...}); -} - void eval(const std::vector& outputs); template diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 30553e263..d196feb55 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -777,45 +777,6 @@ void init_transforms(py::module_& m) { Returns: function: The vectorized function. )pbdoc"); - m.def( - "simplify", - [](const py::args& args) { - std::vector arrays = tree_flatten(args); - simplify(arrays); - }, - R"pbdoc( - simplify(*args) -> None - - Simplify the graph that computes the arrays. - - Run a few fast graph simplification operations to reuse computation and - reduce memory consumption. This function is meant to be run every time - so its overhead should be small, approximately 1ms for a graph with a - few thousand nodes. - - .. code-block:: python - - import mlx.core as mx - - def foo(x): - y = x @ x - z = x @ x - return y + z - - x = mx.ones((10, 10)) - y = foo(x) - z = foo(x) - - # Computes the matmul twice - mx.eval(y) - - # Computes the matmul once - mx.simplify(z) - mx.eval(z) - - Args: - args: Any number of arrays and/or trees of arrays to be simplified. - )pbdoc"); m.def( "export_to_dot", [](py::object file, const py::args& args) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index f120b2ee7..3f5b40864 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -25,7 +25,6 @@ target_sources(tests PRIVATE device_tests.cpp eval_tests.cpp fft_tests.cpp - graph_optimize_tests.cpp load_tests.cpp ops_tests.cpp random_tests.cpp diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index a6252e701..30ac16dd1 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -104,3 +104,77 @@ TEST_CASE("test enable and disable compile") { enable_compiler(); CHECK_THROWS(compile(nullptr)); } + +TEST_CASE("test simplify scalars") { + { + auto a = array(-1.0f); + auto b = array(-1.0f); + auto c = abs(a); + auto d = abs(b); + simplify({c, d}); + CHECK(c.inputs()[0].id() == d.inputs()[0].id()); + } + + { + auto a = array({-1.0f, 2.0f}); + auto b = maximum(a, array(0.0f)); + auto c = maximum(-a, array(0.0f)); + auto d = b + c; + simplify({d}); + CHECK(b.inputs()[1].id() == c.inputs()[1].id()); + } +} + +// TODO rework these tests for compile +/*TEST_CASE("test simplify") { + auto a = array({1.0f, 2.0f}); + auto b = exp(a) + exp(a); + simplify(b); + CHECK(b.inputs()[0].id() == b.inputs()[1].id()); +} + +TEST_CASE("test no simplify") { + auto a = array({1.0f, 2.0f}); + auto b = cos(a) + sin(a); + simplify(b); + CHECK(b.inputs()[0].id() != b.inputs()[1].id()); +} + +TEST_CASE("test simplify multi output") { + { + auto a = array(1.0); + auto b = array(2.0); + auto c = divmod(a, b); + auto d = divmod(a, b); + auto e = c[0] + d[0]; + auto f = c[1] + d[1]; + + simplify({e, f}); + CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id()); + CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id()); + } + + { + auto a = array(1.0); + auto b = array(1.0); + auto c = divmod(a, b); + simplify(c); + CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id()); + CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id()); + CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id()); + } + + // Make sure the output order of multi-output primitives + // is respected in simplification + { + auto a = array(1.0); + auto b = array(2.0); + auto c = divmod(a, b); + auto d = divmod(a, b); + auto e = stack({c[0], c[1], d[0], d[1]}); + simplify(e); + CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item()); + CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id()); + CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id()); + } +}*/ diff --git a/tests/graph_optimize_tests.cpp b/tests/graph_optimize_tests.cpp deleted file mode 100644 index d30005fc6..000000000 --- a/tests/graph_optimize_tests.cpp +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#include "doctest/doctest.h" - -#include "mlx/mlx.h" - -using namespace mlx::core; - -TEST_CASE("test simplify scalars") { - { - auto a = array(-1.0f); - auto b = array(-1.0f); - auto c = abs(a); - auto d = abs(b); - simplify({c, d}); - CHECK(c.inputs()[0].id() == d.inputs()[0].id()); - } - - { - auto a = array({-1.0f, 2.0f}); - auto b = maximum(a, array(0.0f)); - auto c = maximum(-a, array(0.0f)); - auto d = b + c; - simplify({d}); - CHECK(b.inputs()[1].id() == c.inputs()[1].id()); - } -} - -TEST_CASE("test simplify") { - auto a = array({1.0f, 2.0f}); - auto b = exp(a) + exp(a); - simplify(b); - CHECK(b.inputs()[0].id() == b.inputs()[1].id()); -} - -TEST_CASE("test no simplify") { - auto a = array({1.0f, 2.0f}); - auto b = cos(a) + sin(a); - simplify(b); - CHECK(b.inputs()[0].id() != b.inputs()[1].id()); -} - -TEST_CASE("test simplify multi output") { - { - auto a = array(1.0); - auto b = array(2.0); - auto c = divmod(a, b); - auto d = divmod(a, b); - auto e = c[0] + d[0]; - auto f = c[1] + d[1]; - - simplify({e, f}); - CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id()); - CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id()); - } - - { - auto a = array(1.0); - auto b = array(1.0); - auto c = divmod(a, b); - simplify(c); - CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id()); - CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id()); - CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id()); - } - - // Make sure the output order of multi-output primitives - // is respected in simplification - { - auto a = array(1.0); - auto b = array(2.0); - auto c = divmod(a, b); - auto d = divmod(a, b); - auto e = stack({c[0], c[1], d[0], d[1]}); - simplify(e); - CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item()); - CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id()); - CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id()); - } -}