mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
remove simplify
This commit is contained in:
parent
5c78c16f1c
commit
1c3f82ca17
@ -35,169 +35,6 @@ class Synchronizer : public Primitive {
|
|||||||
// are currently under a function transformation.
|
// are currently under a function transformation.
|
||||||
int detail::InTracing::tracing_counter{0};
|
int detail::InTracing::tracing_counter{0};
|
||||||
|
|
||||||
void simplify(const std::vector<array>& outputs) {
|
|
||||||
// Some notes about how this function works
|
|
||||||
//
|
|
||||||
// Step 1: Traverse the graph and build a tape. During the graph
|
|
||||||
// traversal we:
|
|
||||||
// - Build a map of inputs to their parents.
|
|
||||||
// - Record scalar inputs in a map in order to fuse them.
|
|
||||||
// Step 2: Process the tape. A node in the tape has inputs and outputs.
|
|
||||||
// - Scalar inputs are replaced with their canonical scalar
|
|
||||||
// - We check each inputs output nodes. Every output node that matches
|
|
||||||
// the current node gets fused into the current node.
|
|
||||||
std::function<void(const array&)> recurse;
|
|
||||||
std::queue<array> tape;
|
|
||||||
std::unordered_set<std::uintptr_t> cache;
|
|
||||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
|
|
||||||
parents_map;
|
|
||||||
|
|
||||||
// Helpers to identify identical scalars
|
|
||||||
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
|
|
||||||
auto is_scalar = [](const array& a) {
|
|
||||||
return a.is_evaled() && a.ndim() == 0;
|
|
||||||
};
|
|
||||||
auto get_scalar_rep = [](const array& a) {
|
|
||||||
uint64_t v = 0;
|
|
||||||
int dtype;
|
|
||||||
switch (a.dtype().size) {
|
|
||||||
case 1:
|
|
||||||
v = *a.data<uint8_t>();
|
|
||||||
break;
|
|
||||||
case 4:
|
|
||||||
v = *a.data<uint32_t>();
|
|
||||||
break;
|
|
||||||
case 8:
|
|
||||||
v = *a.data<uint64_t>();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
return std::make_pair(v, a.dtype().val);
|
|
||||||
};
|
|
||||||
|
|
||||||
// DFS the graph to build the tape, and log parents and scalars
|
|
||||||
recurse = [&](const array& a) {
|
|
||||||
auto id = a.id();
|
|
||||||
if (cache.find(id) != cache.end()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
for (int i = 0; i < a.inputs().size(); i++) {
|
|
||||||
auto& in = a.inputs()[i];
|
|
||||||
parents_map[in.id()].push_back({a, i});
|
|
||||||
for (auto& s : a.siblings()) {
|
|
||||||
parents_map[in.id()].push_back({s, i});
|
|
||||||
}
|
|
||||||
recurse(in);
|
|
||||||
}
|
|
||||||
cache.insert(id);
|
|
||||||
for (auto& s : a.siblings()) {
|
|
||||||
cache.insert(s.id());
|
|
||||||
}
|
|
||||||
|
|
||||||
tape.push(a);
|
|
||||||
if (is_scalar(a)) {
|
|
||||||
scalars.insert({get_scalar_rep(a), a});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
for (auto& a : outputs) {
|
|
||||||
recurse(a);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper that fuses two arrays in the graph by setting the parents of the
|
|
||||||
// source to point to the destination
|
|
||||||
auto fuse = [&](array& dst, array& src) {
|
|
||||||
// Canonicalize the order of the primitives outputs
|
|
||||||
auto sources = src.outputs();
|
|
||||||
auto dests = dst.outputs();
|
|
||||||
// For each src parent, point it to the corresponding dest
|
|
||||||
for (int i = 0; i < sources.size(); ++i) {
|
|
||||||
auto src_parents = parents_map.find(sources[i].id());
|
|
||||||
if (src_parents == parents_map.end()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto& pairs = parents_map[dests[i].id()];
|
|
||||||
for (auto& parent : src_parents->second) {
|
|
||||||
parent.first.inputs()[parent.second] = dests[i];
|
|
||||||
pairs.push_back(parent);
|
|
||||||
}
|
|
||||||
// Remove the source from the map to avoid fusing with it again
|
|
||||||
parents_map.erase(src_parents);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Depth-1 array equivalence check.
|
|
||||||
auto array_equivalent = [](const array& a, const array& b) {
|
|
||||||
if (!a.has_primitive() || !b.has_primitive()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (a.primitive_id() == b.primitive_id()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
const auto& pa = a.primitive();
|
|
||||||
const auto& pb = b.primitive();
|
|
||||||
if (typeid(pa) != typeid(pb)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (a.inputs().size() != b.inputs().size()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < a.inputs().size(); i++) {
|
|
||||||
if (a.inputs()[i].id() != b.inputs()[i].id()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return pa.is_equivalent(pb);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Walk the graph
|
|
||||||
while (!tape.empty()) {
|
|
||||||
auto arr = std::move(tape.front());
|
|
||||||
tape.pop();
|
|
||||||
|
|
||||||
// Check if we can fuse scalars
|
|
||||||
if (is_scalar(arr)) {
|
|
||||||
auto scalar = scalars.find(get_scalar_rep(arr));
|
|
||||||
if (scalar->second.id() != arr.id()) {
|
|
||||||
fuse(scalar->second, arr);
|
|
||||||
arr = scalar->second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper to check if we can fuse the parents of the
|
|
||||||
// given array
|
|
||||||
auto maybe_fuse_parents = [&](auto& a) {
|
|
||||||
auto parents = parents_map.find(a.id());
|
|
||||||
if (parents != parents_map.end()) {
|
|
||||||
auto N = parents->second.size();
|
|
||||||
std::vector<bool> mask(N, false);
|
|
||||||
for (int i = 0; i < N; i++) {
|
|
||||||
if (mask[i]) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
for (int j = i + 1; j < N; j++) {
|
|
||||||
if (mask[j]) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto& src = parents->second[j].first;
|
|
||||||
auto& dst = parents->second[i].first;
|
|
||||||
if (src.id() != dst.id() && array_equivalent(src, dst)) {
|
|
||||||
fuse(dst, src);
|
|
||||||
mask[j] = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
maybe_fuse_parents(arr);
|
|
||||||
for (auto& s : arr.siblings()) {
|
|
||||||
maybe_fuse_parents(s);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void eval(const std::vector<array>& outputs) {
|
void eval(const std::vector<array>& outputs) {
|
||||||
std::function<void(const array&)> recurse;
|
std::function<void(const array&)> recurse;
|
||||||
std::queue<array> tape;
|
std::queue<array> tape;
|
||||||
|
@ -21,14 +21,6 @@ void disable_compiler();
|
|||||||
*/
|
*/
|
||||||
void enable_compiler();
|
void enable_compiler();
|
||||||
|
|
||||||
/** Fuse equivalent arrays to avoid duplicate execution. */
|
|
||||||
void simplify(const std::vector<array>& outputs);
|
|
||||||
|
|
||||||
template <typename... Arrays>
|
|
||||||
void simplify(Arrays... outputs) {
|
|
||||||
simplify(std::vector<array>{std::forward<Arrays>(outputs)...});
|
|
||||||
}
|
|
||||||
|
|
||||||
void eval(const std::vector<array>& outputs);
|
void eval(const std::vector<array>& outputs);
|
||||||
|
|
||||||
template <typename... Arrays>
|
template <typename... Arrays>
|
||||||
|
@ -777,45 +777,6 @@ void init_transforms(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
function: The vectorized function.
|
function: The vectorized function.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
|
||||||
"simplify",
|
|
||||||
[](const py::args& args) {
|
|
||||||
std::vector<array> arrays = tree_flatten(args);
|
|
||||||
simplify(arrays);
|
|
||||||
},
|
|
||||||
R"pbdoc(
|
|
||||||
simplify(*args) -> None
|
|
||||||
|
|
||||||
Simplify the graph that computes the arrays.
|
|
||||||
|
|
||||||
Run a few fast graph simplification operations to reuse computation and
|
|
||||||
reduce memory consumption. This function is meant to be run every time
|
|
||||||
so its overhead should be small, approximately 1ms for a graph with a
|
|
||||||
few thousand nodes.
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
|
|
||||||
def foo(x):
|
|
||||||
y = x @ x
|
|
||||||
z = x @ x
|
|
||||||
return y + z
|
|
||||||
|
|
||||||
x = mx.ones((10, 10))
|
|
||||||
y = foo(x)
|
|
||||||
z = foo(x)
|
|
||||||
|
|
||||||
# Computes the matmul twice
|
|
||||||
mx.eval(y)
|
|
||||||
|
|
||||||
# Computes the matmul once
|
|
||||||
mx.simplify(z)
|
|
||||||
mx.eval(z)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
args: Any number of arrays and/or trees of arrays to be simplified.
|
|
||||||
)pbdoc");
|
|
||||||
m.def(
|
m.def(
|
||||||
"export_to_dot",
|
"export_to_dot",
|
||||||
[](py::object file, const py::args& args) {
|
[](py::object file, const py::args& args) {
|
||||||
|
@ -25,7 +25,6 @@ target_sources(tests PRIVATE
|
|||||||
device_tests.cpp
|
device_tests.cpp
|
||||||
eval_tests.cpp
|
eval_tests.cpp
|
||||||
fft_tests.cpp
|
fft_tests.cpp
|
||||||
graph_optimize_tests.cpp
|
|
||||||
load_tests.cpp
|
load_tests.cpp
|
||||||
ops_tests.cpp
|
ops_tests.cpp
|
||||||
random_tests.cpp
|
random_tests.cpp
|
||||||
|
@ -104,3 +104,77 @@ TEST_CASE("test enable and disable compile") {
|
|||||||
enable_compiler();
|
enable_compiler();
|
||||||
CHECK_THROWS(compile(nullptr));
|
CHECK_THROWS(compile(nullptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test simplify scalars") {
|
||||||
|
{
|
||||||
|
auto a = array(-1.0f);
|
||||||
|
auto b = array(-1.0f);
|
||||||
|
auto c = abs(a);
|
||||||
|
auto d = abs(b);
|
||||||
|
simplify({c, d});
|
||||||
|
CHECK(c.inputs()[0].id() == d.inputs()[0].id());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto a = array({-1.0f, 2.0f});
|
||||||
|
auto b = maximum(a, array(0.0f));
|
||||||
|
auto c = maximum(-a, array(0.0f));
|
||||||
|
auto d = b + c;
|
||||||
|
simplify({d});
|
||||||
|
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO rework these tests for compile
|
||||||
|
/*TEST_CASE("test simplify") {
|
||||||
|
auto a = array({1.0f, 2.0f});
|
||||||
|
auto b = exp(a) + exp(a);
|
||||||
|
simplify(b);
|
||||||
|
CHECK(b.inputs()[0].id() == b.inputs()[1].id());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test no simplify") {
|
||||||
|
auto a = array({1.0f, 2.0f});
|
||||||
|
auto b = cos(a) + sin(a);
|
||||||
|
simplify(b);
|
||||||
|
CHECK(b.inputs()[0].id() != b.inputs()[1].id());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test simplify multi output") {
|
||||||
|
{
|
||||||
|
auto a = array(1.0);
|
||||||
|
auto b = array(2.0);
|
||||||
|
auto c = divmod(a, b);
|
||||||
|
auto d = divmod(a, b);
|
||||||
|
auto e = c[0] + d[0];
|
||||||
|
auto f = c[1] + d[1];
|
||||||
|
|
||||||
|
simplify({e, f});
|
||||||
|
CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id());
|
||||||
|
CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto a = array(1.0);
|
||||||
|
auto b = array(1.0);
|
||||||
|
auto c = divmod(a, b);
|
||||||
|
simplify(c);
|
||||||
|
CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id());
|
||||||
|
CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id());
|
||||||
|
CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure the output order of multi-output primitives
|
||||||
|
// is respected in simplification
|
||||||
|
{
|
||||||
|
auto a = array(1.0);
|
||||||
|
auto b = array(2.0);
|
||||||
|
auto c = divmod(a, b);
|
||||||
|
auto d = divmod(a, b);
|
||||||
|
auto e = stack({c[0], c[1], d[0], d[1]});
|
||||||
|
simplify(e);
|
||||||
|
CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item<bool>());
|
||||||
|
CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
|
||||||
|
CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
|
||||||
|
}
|
||||||
|
}*/
|
||||||
|
@ -1,80 +0,0 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
|
||||||
|
|
||||||
#include "doctest/doctest.h"
|
|
||||||
|
|
||||||
#include "mlx/mlx.h"
|
|
||||||
|
|
||||||
using namespace mlx::core;
|
|
||||||
|
|
||||||
TEST_CASE("test simplify scalars") {
|
|
||||||
{
|
|
||||||
auto a = array(-1.0f);
|
|
||||||
auto b = array(-1.0f);
|
|
||||||
auto c = abs(a);
|
|
||||||
auto d = abs(b);
|
|
||||||
simplify({c, d});
|
|
||||||
CHECK(c.inputs()[0].id() == d.inputs()[0].id());
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
auto a = array({-1.0f, 2.0f});
|
|
||||||
auto b = maximum(a, array(0.0f));
|
|
||||||
auto c = maximum(-a, array(0.0f));
|
|
||||||
auto d = b + c;
|
|
||||||
simplify({d});
|
|
||||||
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE("test simplify") {
|
|
||||||
auto a = array({1.0f, 2.0f});
|
|
||||||
auto b = exp(a) + exp(a);
|
|
||||||
simplify(b);
|
|
||||||
CHECK(b.inputs()[0].id() == b.inputs()[1].id());
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE("test no simplify") {
|
|
||||||
auto a = array({1.0f, 2.0f});
|
|
||||||
auto b = cos(a) + sin(a);
|
|
||||||
simplify(b);
|
|
||||||
CHECK(b.inputs()[0].id() != b.inputs()[1].id());
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE("test simplify multi output") {
|
|
||||||
{
|
|
||||||
auto a = array(1.0);
|
|
||||||
auto b = array(2.0);
|
|
||||||
auto c = divmod(a, b);
|
|
||||||
auto d = divmod(a, b);
|
|
||||||
auto e = c[0] + d[0];
|
|
||||||
auto f = c[1] + d[1];
|
|
||||||
|
|
||||||
simplify({e, f});
|
|
||||||
CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id());
|
|
||||||
CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id());
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
auto a = array(1.0);
|
|
||||||
auto b = array(1.0);
|
|
||||||
auto c = divmod(a, b);
|
|
||||||
simplify(c);
|
|
||||||
CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id());
|
|
||||||
CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id());
|
|
||||||
CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure the output order of multi-output primitives
|
|
||||||
// is respected in simplification
|
|
||||||
{
|
|
||||||
auto a = array(1.0);
|
|
||||||
auto b = array(2.0);
|
|
||||||
auto c = divmod(a, b);
|
|
||||||
auto d = divmod(a, b);
|
|
||||||
auto e = stack({c[0], c[1], d[0], d[1]});
|
|
||||||
simplify(e);
|
|
||||||
CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item<bool>());
|
|
||||||
CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
|
|
||||||
CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user