mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	remove simplify
This commit is contained in:
		@@ -35,169 +35,6 @@ class Synchronizer : public Primitive {
 | 
				
			|||||||
// are currently under a function transformation.
 | 
					// are currently under a function transformation.
 | 
				
			||||||
int detail::InTracing::tracing_counter{0};
 | 
					int detail::InTracing::tracing_counter{0};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void simplify(const std::vector<array>& outputs) {
 | 
					 | 
				
			||||||
  // Some notes about how this function works
 | 
					 | 
				
			||||||
  //
 | 
					 | 
				
			||||||
  // Step 1: Traverse the graph and build a tape. During the graph
 | 
					 | 
				
			||||||
  // traversal we:
 | 
					 | 
				
			||||||
  //      - Build a map of inputs to their parents.
 | 
					 | 
				
			||||||
  //      - Record scalar inputs in a map in order to fuse them.
 | 
					 | 
				
			||||||
  // Step 2: Process the tape. A node in the tape has inputs and outputs.
 | 
					 | 
				
			||||||
  //      - Scalar inputs are replaced with their canonical scalar
 | 
					 | 
				
			||||||
  //      - We check each inputs output nodes. Every output node that matches
 | 
					 | 
				
			||||||
  //        the current node gets fused into the current node.
 | 
					 | 
				
			||||||
  std::function<void(const array&)> recurse;
 | 
					 | 
				
			||||||
  std::queue<array> tape;
 | 
					 | 
				
			||||||
  std::unordered_set<std::uintptr_t> cache;
 | 
					 | 
				
			||||||
  std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
 | 
					 | 
				
			||||||
      parents_map;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Helpers to identify identical scalars
 | 
					 | 
				
			||||||
  std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
 | 
					 | 
				
			||||||
  auto is_scalar = [](const array& a) {
 | 
					 | 
				
			||||||
    return a.is_evaled() && a.ndim() == 0;
 | 
					 | 
				
			||||||
  };
 | 
					 | 
				
			||||||
  auto get_scalar_rep = [](const array& a) {
 | 
					 | 
				
			||||||
    uint64_t v = 0;
 | 
					 | 
				
			||||||
    int dtype;
 | 
					 | 
				
			||||||
    switch (a.dtype().size) {
 | 
					 | 
				
			||||||
      case 1:
 | 
					 | 
				
			||||||
        v = *a.data<uint8_t>();
 | 
					 | 
				
			||||||
        break;
 | 
					 | 
				
			||||||
      case 4:
 | 
					 | 
				
			||||||
        v = *a.data<uint32_t>();
 | 
					 | 
				
			||||||
        break;
 | 
					 | 
				
			||||||
      case 8:
 | 
					 | 
				
			||||||
        v = *a.data<uint64_t>();
 | 
					 | 
				
			||||||
        break;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    return std::make_pair(v, a.dtype().val);
 | 
					 | 
				
			||||||
  };
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // DFS the graph to build the tape, and log parents and scalars
 | 
					 | 
				
			||||||
  recurse = [&](const array& a) {
 | 
					 | 
				
			||||||
    auto id = a.id();
 | 
					 | 
				
			||||||
    if (cache.find(id) != cache.end()) {
 | 
					 | 
				
			||||||
      return;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    for (int i = 0; i < a.inputs().size(); i++) {
 | 
					 | 
				
			||||||
      auto& in = a.inputs()[i];
 | 
					 | 
				
			||||||
      parents_map[in.id()].push_back({a, i});
 | 
					 | 
				
			||||||
      for (auto& s : a.siblings()) {
 | 
					 | 
				
			||||||
        parents_map[in.id()].push_back({s, i});
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      recurse(in);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    cache.insert(id);
 | 
					 | 
				
			||||||
    for (auto& s : a.siblings()) {
 | 
					 | 
				
			||||||
      cache.insert(s.id());
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    tape.push(a);
 | 
					 | 
				
			||||||
    if (is_scalar(a)) {
 | 
					 | 
				
			||||||
      scalars.insert({get_scalar_rep(a), a});
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  };
 | 
					 | 
				
			||||||
  for (auto& a : outputs) {
 | 
					 | 
				
			||||||
    recurse(a);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Helper that fuses two arrays in the graph by setting the parents of the
 | 
					 | 
				
			||||||
  // source to point to the destination
 | 
					 | 
				
			||||||
  auto fuse = [&](array& dst, array& src) {
 | 
					 | 
				
			||||||
    // Canonicalize the order of the primitives outputs
 | 
					 | 
				
			||||||
    auto sources = src.outputs();
 | 
					 | 
				
			||||||
    auto dests = dst.outputs();
 | 
					 | 
				
			||||||
    // For each src parent, point it to the corresponding dest
 | 
					 | 
				
			||||||
    for (int i = 0; i < sources.size(); ++i) {
 | 
					 | 
				
			||||||
      auto src_parents = parents_map.find(sources[i].id());
 | 
					 | 
				
			||||||
      if (src_parents == parents_map.end()) {
 | 
					 | 
				
			||||||
        continue;
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      auto& pairs = parents_map[dests[i].id()];
 | 
					 | 
				
			||||||
      for (auto& parent : src_parents->second) {
 | 
					 | 
				
			||||||
        parent.first.inputs()[parent.second] = dests[i];
 | 
					 | 
				
			||||||
        pairs.push_back(parent);
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      // Remove the source from the map to avoid fusing with it again
 | 
					 | 
				
			||||||
      parents_map.erase(src_parents);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  };
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Depth-1 array equivalence check.
 | 
					 | 
				
			||||||
  auto array_equivalent = [](const array& a, const array& b) {
 | 
					 | 
				
			||||||
    if (!a.has_primitive() || !b.has_primitive()) {
 | 
					 | 
				
			||||||
      return false;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    if (a.primitive_id() == b.primitive_id()) {
 | 
					 | 
				
			||||||
      return false;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    const auto& pa = a.primitive();
 | 
					 | 
				
			||||||
    const auto& pb = b.primitive();
 | 
					 | 
				
			||||||
    if (typeid(pa) != typeid(pb)) {
 | 
					 | 
				
			||||||
      return false;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if (a.inputs().size() != b.inputs().size()) {
 | 
					 | 
				
			||||||
      return false;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for (int i = 0; i < a.inputs().size(); i++) {
 | 
					 | 
				
			||||||
      if (a.inputs()[i].id() != b.inputs()[i].id()) {
 | 
					 | 
				
			||||||
        return false;
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return pa.is_equivalent(pb);
 | 
					 | 
				
			||||||
  };
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Walk the graph
 | 
					 | 
				
			||||||
  while (!tape.empty()) {
 | 
					 | 
				
			||||||
    auto arr = std::move(tape.front());
 | 
					 | 
				
			||||||
    tape.pop();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Check if we can fuse scalars
 | 
					 | 
				
			||||||
    if (is_scalar(arr)) {
 | 
					 | 
				
			||||||
      auto scalar = scalars.find(get_scalar_rep(arr));
 | 
					 | 
				
			||||||
      if (scalar->second.id() != arr.id()) {
 | 
					 | 
				
			||||||
        fuse(scalar->second, arr);
 | 
					 | 
				
			||||||
        arr = scalar->second;
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Helper to check if we can fuse the parents of the
 | 
					 | 
				
			||||||
    // given array
 | 
					 | 
				
			||||||
    auto maybe_fuse_parents = [&](auto& a) {
 | 
					 | 
				
			||||||
      auto parents = parents_map.find(a.id());
 | 
					 | 
				
			||||||
      if (parents != parents_map.end()) {
 | 
					 | 
				
			||||||
        auto N = parents->second.size();
 | 
					 | 
				
			||||||
        std::vector<bool> mask(N, false);
 | 
					 | 
				
			||||||
        for (int i = 0; i < N; i++) {
 | 
					 | 
				
			||||||
          if (mask[i]) {
 | 
					 | 
				
			||||||
            continue;
 | 
					 | 
				
			||||||
          }
 | 
					 | 
				
			||||||
          for (int j = i + 1; j < N; j++) {
 | 
					 | 
				
			||||||
            if (mask[j]) {
 | 
					 | 
				
			||||||
              continue;
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
            auto& src = parents->second[j].first;
 | 
					 | 
				
			||||||
            auto& dst = parents->second[i].first;
 | 
					 | 
				
			||||||
            if (src.id() != dst.id() && array_equivalent(src, dst)) {
 | 
					 | 
				
			||||||
              fuse(dst, src);
 | 
					 | 
				
			||||||
              mask[j] = true;
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
          }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    };
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    maybe_fuse_parents(arr);
 | 
					 | 
				
			||||||
    for (auto& s : arr.siblings()) {
 | 
					 | 
				
			||||||
      maybe_fuse_parents(s);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
void eval(const std::vector<array>& outputs) {
 | 
					void eval(const std::vector<array>& outputs) {
 | 
				
			||||||
  std::function<void(const array&)> recurse;
 | 
					  std::function<void(const array&)> recurse;
 | 
				
			||||||
  std::queue<array> tape;
 | 
					  std::queue<array> tape;
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -21,14 +21,6 @@ void disable_compiler();
 | 
				
			|||||||
 */
 | 
					 */
 | 
				
			||||||
void enable_compiler();
 | 
					void enable_compiler();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/** Fuse equivalent arrays to avoid duplicate execution. */
 | 
					 | 
				
			||||||
void simplify(const std::vector<array>& outputs);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
template <typename... Arrays>
 | 
					 | 
				
			||||||
void simplify(Arrays... outputs) {
 | 
					 | 
				
			||||||
  simplify(std::vector<array>{std::forward<Arrays>(outputs)...});
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
void eval(const std::vector<array>& outputs);
 | 
					void eval(const std::vector<array>& outputs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename... Arrays>
 | 
					template <typename... Arrays>
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -777,45 +777,6 @@ void init_transforms(py::module_& m) {
 | 
				
			|||||||
        Returns:
 | 
					        Returns:
 | 
				
			||||||
            function: The vectorized function.
 | 
					            function: The vectorized function.
 | 
				
			||||||
      )pbdoc");
 | 
					      )pbdoc");
 | 
				
			||||||
  m.def(
 | 
					 | 
				
			||||||
      "simplify",
 | 
					 | 
				
			||||||
      [](const py::args& args) {
 | 
					 | 
				
			||||||
        std::vector<array> arrays = tree_flatten(args);
 | 
					 | 
				
			||||||
        simplify(arrays);
 | 
					 | 
				
			||||||
      },
 | 
					 | 
				
			||||||
      R"pbdoc(
 | 
					 | 
				
			||||||
        simplify(*args) -> None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Simplify the graph that computes the arrays.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Run a few fast graph simplification operations to reuse computation and
 | 
					 | 
				
			||||||
        reduce memory consumption. This function is meant to be run every time
 | 
					 | 
				
			||||||
        so its overhead should be small, approximately 1ms for a graph with a
 | 
					 | 
				
			||||||
        few thousand nodes.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        .. code-block:: python
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
          import mlx.core as mx
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
          def foo(x):
 | 
					 | 
				
			||||||
            y = x @ x
 | 
					 | 
				
			||||||
            z = x @ x
 | 
					 | 
				
			||||||
            return y + z
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
          x = mx.ones((10, 10))
 | 
					 | 
				
			||||||
          y = foo(x)
 | 
					 | 
				
			||||||
          z = foo(x)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
          # Computes the matmul twice
 | 
					 | 
				
			||||||
          mx.eval(y)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
          # Computes the matmul once
 | 
					 | 
				
			||||||
          mx.simplify(z)
 | 
					 | 
				
			||||||
          mx.eval(z)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        Args:
 | 
					 | 
				
			||||||
          args: Any number of arrays and/or trees of arrays to be simplified.
 | 
					 | 
				
			||||||
      )pbdoc");
 | 
					 | 
				
			||||||
  m.def(
 | 
					  m.def(
 | 
				
			||||||
      "export_to_dot",
 | 
					      "export_to_dot",
 | 
				
			||||||
      [](py::object file, const py::args& args) {
 | 
					      [](py::object file, const py::args& args) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -25,7 +25,6 @@ target_sources(tests PRIVATE
 | 
				
			|||||||
  device_tests.cpp
 | 
					  device_tests.cpp
 | 
				
			||||||
  eval_tests.cpp
 | 
					  eval_tests.cpp
 | 
				
			||||||
  fft_tests.cpp
 | 
					  fft_tests.cpp
 | 
				
			||||||
  graph_optimize_tests.cpp
 | 
					 | 
				
			||||||
  load_tests.cpp
 | 
					  load_tests.cpp
 | 
				
			||||||
  ops_tests.cpp
 | 
					  ops_tests.cpp
 | 
				
			||||||
  random_tests.cpp
 | 
					  random_tests.cpp
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -104,3 +104,77 @@ TEST_CASE("test enable and disable compile") {
 | 
				
			|||||||
  enable_compiler();
 | 
					  enable_compiler();
 | 
				
			||||||
  CHECK_THROWS(compile(nullptr));
 | 
					  CHECK_THROWS(compile(nullptr));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_CASE("test simplify scalars") {
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					    auto a = array(-1.0f);
 | 
				
			||||||
 | 
					    auto b = array(-1.0f);
 | 
				
			||||||
 | 
					    auto c = abs(a);
 | 
				
			||||||
 | 
					    auto d = abs(b);
 | 
				
			||||||
 | 
					    simplify({c, d});
 | 
				
			||||||
 | 
					    CHECK(c.inputs()[0].id() == d.inputs()[0].id());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					    auto a = array({-1.0f, 2.0f});
 | 
				
			||||||
 | 
					    auto b = maximum(a, array(0.0f));
 | 
				
			||||||
 | 
					    auto c = maximum(-a, array(0.0f));
 | 
				
			||||||
 | 
					    auto d = b + c;
 | 
				
			||||||
 | 
					    simplify({d});
 | 
				
			||||||
 | 
					    CHECK(b.inputs()[1].id() == c.inputs()[1].id());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TODO rework these tests for compile
 | 
				
			||||||
 | 
					/*TEST_CASE("test simplify") {
 | 
				
			||||||
 | 
					  auto a = array({1.0f, 2.0f});
 | 
				
			||||||
 | 
					  auto b = exp(a) + exp(a);
 | 
				
			||||||
 | 
					  simplify(b);
 | 
				
			||||||
 | 
					  CHECK(b.inputs()[0].id() == b.inputs()[1].id());
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_CASE("test no simplify") {
 | 
				
			||||||
 | 
					  auto a = array({1.0f, 2.0f});
 | 
				
			||||||
 | 
					  auto b = cos(a) + sin(a);
 | 
				
			||||||
 | 
					  simplify(b);
 | 
				
			||||||
 | 
					  CHECK(b.inputs()[0].id() != b.inputs()[1].id());
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_CASE("test simplify multi output") {
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					    auto a = array(1.0);
 | 
				
			||||||
 | 
					    auto b = array(2.0);
 | 
				
			||||||
 | 
					    auto c = divmod(a, b);
 | 
				
			||||||
 | 
					    auto d = divmod(a, b);
 | 
				
			||||||
 | 
					    auto e = c[0] + d[0];
 | 
				
			||||||
 | 
					    auto f = c[1] + d[1];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    simplify({e, f});
 | 
				
			||||||
 | 
					    CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id());
 | 
				
			||||||
 | 
					    CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					    auto a = array(1.0);
 | 
				
			||||||
 | 
					    auto b = array(1.0);
 | 
				
			||||||
 | 
					    auto c = divmod(a, b);
 | 
				
			||||||
 | 
					    simplify(c);
 | 
				
			||||||
 | 
					    CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id());
 | 
				
			||||||
 | 
					    CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id());
 | 
				
			||||||
 | 
					    CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Make sure the output order of multi-output primitives
 | 
				
			||||||
 | 
					  // is respected in simplification
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					    auto a = array(1.0);
 | 
				
			||||||
 | 
					    auto b = array(2.0);
 | 
				
			||||||
 | 
					    auto c = divmod(a, b);
 | 
				
			||||||
 | 
					    auto d = divmod(a, b);
 | 
				
			||||||
 | 
					    auto e = stack({c[0], c[1], d[0], d[1]});
 | 
				
			||||||
 | 
					    simplify(e);
 | 
				
			||||||
 | 
					    CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item<bool>());
 | 
				
			||||||
 | 
					    CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
 | 
				
			||||||
 | 
					    CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}*/
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,80 +0,0 @@
 | 
				
			|||||||
// Copyright © 2023 Apple Inc.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#include "doctest/doctest.h"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#include "mlx/mlx.h"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
using namespace mlx::core;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
TEST_CASE("test simplify scalars") {
 | 
					 | 
				
			||||||
  {
 | 
					 | 
				
			||||||
    auto a = array(-1.0f);
 | 
					 | 
				
			||||||
    auto b = array(-1.0f);
 | 
					 | 
				
			||||||
    auto c = abs(a);
 | 
					 | 
				
			||||||
    auto d = abs(b);
 | 
					 | 
				
			||||||
    simplify({c, d});
 | 
					 | 
				
			||||||
    CHECK(c.inputs()[0].id() == d.inputs()[0].id());
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  {
 | 
					 | 
				
			||||||
    auto a = array({-1.0f, 2.0f});
 | 
					 | 
				
			||||||
    auto b = maximum(a, array(0.0f));
 | 
					 | 
				
			||||||
    auto c = maximum(-a, array(0.0f));
 | 
					 | 
				
			||||||
    auto d = b + c;
 | 
					 | 
				
			||||||
    simplify({d});
 | 
					 | 
				
			||||||
    CHECK(b.inputs()[1].id() == c.inputs()[1].id());
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
TEST_CASE("test simplify") {
 | 
					 | 
				
			||||||
  auto a = array({1.0f, 2.0f});
 | 
					 | 
				
			||||||
  auto b = exp(a) + exp(a);
 | 
					 | 
				
			||||||
  simplify(b);
 | 
					 | 
				
			||||||
  CHECK(b.inputs()[0].id() == b.inputs()[1].id());
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
TEST_CASE("test no simplify") {
 | 
					 | 
				
			||||||
  auto a = array({1.0f, 2.0f});
 | 
					 | 
				
			||||||
  auto b = cos(a) + sin(a);
 | 
					 | 
				
			||||||
  simplify(b);
 | 
					 | 
				
			||||||
  CHECK(b.inputs()[0].id() != b.inputs()[1].id());
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
TEST_CASE("test simplify multi output") {
 | 
					 | 
				
			||||||
  {
 | 
					 | 
				
			||||||
    auto a = array(1.0);
 | 
					 | 
				
			||||||
    auto b = array(2.0);
 | 
					 | 
				
			||||||
    auto c = divmod(a, b);
 | 
					 | 
				
			||||||
    auto d = divmod(a, b);
 | 
					 | 
				
			||||||
    auto e = c[0] + d[0];
 | 
					 | 
				
			||||||
    auto f = c[1] + d[1];
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    simplify({e, f});
 | 
					 | 
				
			||||||
    CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id());
 | 
					 | 
				
			||||||
    CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id());
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  {
 | 
					 | 
				
			||||||
    auto a = array(1.0);
 | 
					 | 
				
			||||||
    auto b = array(1.0);
 | 
					 | 
				
			||||||
    auto c = divmod(a, b);
 | 
					 | 
				
			||||||
    simplify(c);
 | 
					 | 
				
			||||||
    CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id());
 | 
					 | 
				
			||||||
    CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id());
 | 
					 | 
				
			||||||
    CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id());
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Make sure the output order of multi-output primitives
 | 
					 | 
				
			||||||
  // is respected in simplification
 | 
					 | 
				
			||||||
  {
 | 
					 | 
				
			||||||
    auto a = array(1.0);
 | 
					 | 
				
			||||||
    auto b = array(2.0);
 | 
					 | 
				
			||||||
    auto c = divmod(a, b);
 | 
					 | 
				
			||||||
    auto d = divmod(a, b);
 | 
					 | 
				
			||||||
    auto e = stack({c[0], c[1], d[0], d[1]});
 | 
					 | 
				
			||||||
    simplify(e);
 | 
					 | 
				
			||||||
    CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item<bool>());
 | 
					 | 
				
			||||||
    CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
 | 
					 | 
				
			||||||
    CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
		Reference in New Issue
	
	Block a user