mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compile front-end (#476)
* fix tests for linux * make a move on compile * basic compile scaffold works * compile binding * clean * fix * fix grad, more tests * basic python tests * fix segfault on python exit * compile works with python closures * fix test * fix python globals bug, and erase * simplify * more cpp tests * bug fix with move function and compile at exit * simplify inputs also * enable and disable compiler * remove simplify * simplify tests use compile now * fix multi-output with compile * clear output tree from cache when function goes out of scope * ../python/src/transforms.cpp * remove closure capture * comments
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <algorithm>
|
||||
#include <future>
|
||||
#include <map>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
@@ -35,169 +34,6 @@ class Synchronizer : public Primitive {
|
||||
// are currently under a function transformation.
|
||||
int detail::InTracing::tracing_counter{0};
|
||||
|
||||
void simplify(const std::vector<array>& outputs) {
|
||||
// Some notes about how this function works
|
||||
//
|
||||
// Step 1: Traverse the graph and build a tape. During the graph
|
||||
// traversal we:
|
||||
// - Build a map of inputs to their parents.
|
||||
// - Record scalar inputs in a map in order to fuse them.
|
||||
// Step 2: Process the tape. A node in the tape has inputs and outputs.
|
||||
// - Scalar inputs are replaced with their canonical scalar
|
||||
// - We check each inputs output nodes. Every output node that matches
|
||||
// the current node gets fused into the current node.
|
||||
std::function<void(const array&)> recurse;
|
||||
std::queue<array> tape;
|
||||
std::unordered_set<std::uintptr_t> cache;
|
||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
|
||||
parents_map;
|
||||
|
||||
// Helpers to identify identical scalars
|
||||
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
|
||||
auto is_scalar = [](const array& a) {
|
||||
return a.is_evaled() && a.ndim() == 0;
|
||||
};
|
||||
auto get_scalar_rep = [](const array& a) {
|
||||
uint64_t v = 0;
|
||||
int dtype;
|
||||
switch (a.dtype().size) {
|
||||
case 1:
|
||||
v = *a.data<uint8_t>();
|
||||
break;
|
||||
case 4:
|
||||
v = *a.data<uint32_t>();
|
||||
break;
|
||||
case 8:
|
||||
v = *a.data<uint64_t>();
|
||||
break;
|
||||
}
|
||||
return std::make_pair(v, a.dtype().val);
|
||||
};
|
||||
|
||||
// DFS the graph to build the tape, and log parents and scalars
|
||||
recurse = [&](const array& a) {
|
||||
auto id = a.id();
|
||||
if (cache.find(id) != cache.end()) {
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < a.inputs().size(); i++) {
|
||||
auto& in = a.inputs()[i];
|
||||
parents_map[in.id()].push_back({a, i});
|
||||
for (auto& s : a.siblings()) {
|
||||
parents_map[in.id()].push_back({s, i});
|
||||
}
|
||||
recurse(in);
|
||||
}
|
||||
cache.insert(id);
|
||||
for (auto& s : a.siblings()) {
|
||||
cache.insert(s.id());
|
||||
}
|
||||
|
||||
tape.push(a);
|
||||
if (is_scalar(a)) {
|
||||
scalars.insert({get_scalar_rep(a), a});
|
||||
}
|
||||
};
|
||||
for (auto& a : outputs) {
|
||||
recurse(a);
|
||||
}
|
||||
|
||||
// Helper that fuses two arrays in the graph by setting the parents of the
|
||||
// source to point to the destination
|
||||
auto fuse = [&](array& dst, array& src) {
|
||||
// Canonicalize the order of the primitives outputs
|
||||
auto sources = src.outputs();
|
||||
auto dests = dst.outputs();
|
||||
// For each src parent, point it to the corresponding dest
|
||||
for (int i = 0; i < sources.size(); ++i) {
|
||||
auto src_parents = parents_map.find(sources[i].id());
|
||||
if (src_parents == parents_map.end()) {
|
||||
continue;
|
||||
}
|
||||
auto& pairs = parents_map[dests[i].id()];
|
||||
for (auto& parent : src_parents->second) {
|
||||
parent.first.inputs()[parent.second] = dests[i];
|
||||
pairs.push_back(parent);
|
||||
}
|
||||
// Remove the source from the map to avoid fusing with it again
|
||||
parents_map.erase(src_parents);
|
||||
}
|
||||
};
|
||||
|
||||
// Depth-1 array equivalence check.
|
||||
auto array_equivalent = [](const array& a, const array& b) {
|
||||
if (!a.has_primitive() || !b.has_primitive()) {
|
||||
return false;
|
||||
}
|
||||
if (a.primitive_id() == b.primitive_id()) {
|
||||
return false;
|
||||
}
|
||||
const auto& pa = a.primitive();
|
||||
const auto& pb = b.primitive();
|
||||
if (typeid(pa) != typeid(pb)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (a.inputs().size() != b.inputs().size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < a.inputs().size(); i++) {
|
||||
if (a.inputs()[i].id() != b.inputs()[i].id()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return pa.is_equivalent(pb);
|
||||
};
|
||||
|
||||
// Walk the graph
|
||||
while (!tape.empty()) {
|
||||
auto arr = std::move(tape.front());
|
||||
tape.pop();
|
||||
|
||||
// Check if we can fuse scalars
|
||||
if (is_scalar(arr)) {
|
||||
auto scalar = scalars.find(get_scalar_rep(arr));
|
||||
if (scalar->second.id() != arr.id()) {
|
||||
fuse(scalar->second, arr);
|
||||
arr = scalar->second;
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to check if we can fuse the parents of the
|
||||
// given array
|
||||
auto maybe_fuse_parents = [&](auto& a) {
|
||||
auto parents = parents_map.find(a.id());
|
||||
if (parents != parents_map.end()) {
|
||||
auto N = parents->second.size();
|
||||
std::vector<bool> mask(N, false);
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (mask[i]) {
|
||||
continue;
|
||||
}
|
||||
for (int j = i + 1; j < N; j++) {
|
||||
if (mask[j]) {
|
||||
continue;
|
||||
}
|
||||
auto& src = parents->second[j].first;
|
||||
auto& dst = parents->second[i].first;
|
||||
if (src.id() != dst.id() && array_equivalent(src, dst)) {
|
||||
fuse(dst, src);
|
||||
mask[j] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
maybe_fuse_parents(arr);
|
||||
for (auto& s : arr.siblings()) {
|
||||
maybe_fuse_parents(s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void eval(const std::vector<array>& outputs) {
|
||||
std::function<void(const array&)> recurse;
|
||||
std::queue<array> tape;
|
||||
|
||||
Reference in New Issue
Block a user